Open in Studio Open in Colab
[ ]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

Pretrained Weights

Written by: Nils Lehmann

In this tutorial, we demonstrate some of the pretrained weights available in TorchGeo. The implementation follows torchvision’s recently introduced Multi-Weight API. We will use the EuroSAT dataset throughout this tutorial, specifically a subset containing only 100 images.

It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.

Setup

First, we install TorchGeo.

[ ]:
%pip install torchgeo

Imports

Next, we import TorchGeo and any other libraries we need.

[ ]:
%matplotlib inline

import os
import tempfile

import timm
import torch
from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask

The following variables can be used to control training.

[ ]:
batch_size = 10
num_workers = 2
max_epochs = 10
fast_dev_run = False
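
As a rough guide to how these settings interact, the sketch below estimates how many training batches a run processes. It is a back-of-the-envelope calculation, not part of TorchGeo or Lightning: the helper `estimate_training_steps` is hypothetical, and the training-set size of 60 is an assumed figure used only for illustration.

```python
import math


def estimate_training_steps(
    num_train_samples: int,
    batch_size: int,
    max_epochs: int,
    fast_dev_run: bool = False,
) -> int:
    """Estimate how many training batches a run will process.

    Simplifying assumptions: one device, no gradient accumulation, the last
    partial batch is kept, and training is not stopped early.
    """
    if fast_dev_run:
        # With fast_dev_run=True, Lightning runs a single training batch
        # as a quick smoke test instead of a full training run.
        return 1
    batches_per_epoch = math.ceil(num_train_samples / batch_size)
    return batches_per_epoch * max_epochs


# 60 training samples is an assumed figure used only for illustration.
full_run = estimate_training_steps(60, batch_size=10, max_epochs=10)
smoke_test = estimate_training_steps(
    60, batch_size=10, max_epochs=10, fast_dev_run=True
)
print(full_run, smoke_test)
```

Setting `fast_dev_run = True` is a quick way to check that the whole pipeline runs end to end before committing to a full training run.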

Datamodule

We will utilize TorchGeo’s Lightning datamodules to organize the dataloader setup.

[ ]:
root = os.path.join(tempfile.gettempdir(), 'eurosat100')
datamodule = EuroSAT100DataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)

Weights

Pretrained weights for torchgeo.models are available and grouped by satellite or sensor type: sensor-agnostic, Landsat, NAIP, Sentinel-1, and Sentinel-2. Refer to the model documentation for a complete list of weights, and choose the pretrained weights that best match your use case.

While some weights only accept RGB input, others have been pretrained on Sentinel-2 imagery with all 13 input channels and can therefore be useful for transfer learning tasks involving Sentinel-2 data.

To use these weights, select the corresponding enum member as follows:

[ ]:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO

This set of weights is a member of a torchvision WeightsEnum and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing models with them.
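
To make the idea concrete, here is a minimal, standard-library sketch of the pattern: an enum whose members each bundle a download URL with a metadata dictionary. It is not torchvision's or TorchGeo's implementation. The names `SketchWeights`, `SketchResNet18Weights`, and `weights_for_channels`, the URLs, and the metadata values are all hypothetical placeholders.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any


@dataclass(frozen=True)
class SketchWeights:
    """Hypothetical stand-in for one set of pretrained weights."""

    url: str  # where the checkpoint would be downloaded from
    meta: dict[str, Any] = field(default_factory=dict)  # free-form metadata


class SketchResNet18Weights(Enum):
    """Hypothetical enum with one member per available set of weights."""

    # Placeholder URLs and metadata, not real checkpoints.
    SENTINEL2_ALL_EXAMPLE = SketchWeights(
        url='https://example.com/resnet18_sentinel2_all.pth',
        meta={'in_chans': 13, 'model': 'resnet18'},
    )
    SENTINEL2_RGB_EXAMPLE = SketchWeights(
        url='https://example.com/resnet18_sentinel2_rgb.pth',
        meta={'in_chans': 3, 'model': 'resnet18'},
    )


def weights_for_channels(in_chans: int) -> list[str]:
    """Return the names of all members whose metadata matches `in_chans`."""
    return [
        member.name
        for member in SketchResNet18Weights
        if member.value.meta['in_chans'] == in_chans
    ]


chosen = SketchResNet18Weights.SENTINEL2_ALL_EXAMPLE
print(chosen.value.url)
print(chosen.value.meta['in_chans'])
print(weights_for_channels(3))
```

Keeping the metadata next to the URL means code can check, for instance, that the number of input channels matches the data before any checkpoint is downloaded. In this sketch the fields are reached through `.value`, whereas the weights used in this tutorial expose them directly on the member, as in `weights.meta['in_chans']`.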

torchgeo.trainers provides specialized task classes that simplify training workflows for common geospatial tasks. Depending on your objective, you can select the appropriate trainer class, such as ClassificationTask for classification, SemanticSegmentationTask for semantic segmentation, or other task-specific trainers. Check the trainers documentation for more information.

Given that EuroSAT is a classification dataset, we can use a ClassificationTask object that holds the model and optimizer as well as the training logic.

[ ]:
task = ClassificationTask(
    model='resnet18',
    loss='ce',
    weights=weights,
    in_channels=13,
    num_classes=10,
    lr=0.001,
    patience=5,
)

If you do not want to utilize the ClassificationTask functionality for your experiments, you can also just create a timm model with pretrained weights from TorchGeo as follows:

[ ]:
in_chans = weights.meta['in_chans']
model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
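
Passing `strict=False` lets the load succeed even when the checkpoint and the model do not contain exactly the same parameter names, for example when a checkpoint stores only a backbone and the newly created classification head has no matching entry. The plain-Python sketch below illustrates that bookkeeping on lists of key names. It is not PyTorch's implementation, and the helper `compare_keys` and the key names are hypothetical.

```python
def compare_keys(
    model_keys: list[str], checkpoint_keys: list[str]
) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) key names, preserving input order.

    missing:    keys the model expects but the checkpoint does not provide.
    unexpected: keys the checkpoint provides but the model does not have.
    """
    model_set = set(model_keys)
    checkpoint_set = set(checkpoint_keys)
    missing = [key for key in model_keys if key not in checkpoint_set]
    unexpected = [key for key in checkpoint_keys if key not in model_set]
    return missing, unexpected


# Hypothetical parameter names: a tiny model and a backbone-only checkpoint.
model_keys = ['conv1.weight', 'bn1.weight', 'bn1.bias', 'fc.weight', 'fc.bias']
checkpoint_keys = ['conv1.weight', 'bn1.weight', 'bn1.bias']

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print('missing from checkpoint:', missing)
print('unexpected in checkpoint:', unexpected)
```

With a real model, `load_state_dict` returns an object with `missing_keys` and `unexpected_keys` attributes. Inspecting that return value confirms that only the parameters you expect, such as the classification head, were left at their random initialization.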

Training

To train our pretrained model on the EuroSAT dataset, we will use Lightning’s Trainer. For a more detailed explanation of how TorchGeo uses Lightning, check out this tutorial.

[ ]:
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
default_root_dir = os.path.join(tempfile.gettempdir(), 'experiments')

[ ]:
trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    min_epochs=1,
    max_epochs=max_epochs,
)

[ ]:
trainer.fit(model=task, datamodule=datamodule)
