Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
In this tutorial, we demonstrate some of the pretrained weights available in TorchGeo. The implementation follows torchvision’s recently introduced Multi-Weight API. We will use the EuroSAT dataset throughout this tutorial.
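The Multi-Weight API is built around an enum in which each member describes one pretrained checkpoint. The minimal, self-contained sketch below shows that pattern. The names Weights and ToyResNet50Weights and the example.com URLs are hypothetical placeholders, not TorchGeo or torchvision classes. torchvision’s real Weights dataclass also carries a transforms field, which is omitted here.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


@dataclass
class Weights:
    """Hypothetical stand-in describing one pretrained checkpoint."""

    url: str
    meta: Dict[str, Any] = field(default_factory=dict)


class ToyResNet50Weights(Enum):
    """Hypothetical enum following the Multi-Weight API pattern."""

    # Placeholder URLs: real weight enums point at hosted checkpoint files.
    RGB_EXAMPLE = Weights(
        url="https://example.com/resnet50_rgb.pth",
        meta={"in_chans": 3, "bands": "RGB"},
    )
    SENTINEL2_EXAMPLE = Weights(
        url="https://example.com/resnet50_s2.pth",
        meta={"in_chans": 13, "bands": "Sentinel-2 (all)"},
    )

    @property
    def url(self) -> str:
        # Forward attribute access to the wrapped Weights record.
        return self.value.url

    @property
    def meta(self) -> Dict[str, Any]:
        return self.value.meta


weights = ToyResNet50Weights.SENTINEL2_EXAMPLE
print(weights.meta["in_chans"])  # 13
```

Selecting a checkpoint is just picking an enum member, and everything a model builder needs, such as the number of input channels, is read from its metadata.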
It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.
First, we install TorchGeo.
%pip install torchgeo
Next, we import TorchGeo and any other libraries we need.
%matplotlib inline
import os
import csv
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import timm
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask
from torchgeo.models import ResNet50_Weights, ViTSmall16_Weights
# We set a flag to check whether the notebook is currently being run by PyTest.
# If it is, we skip the expensive training.
in_tests = "PYTEST_CURRENT_TEST" in os.environ
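pytest exports the PYTEST_CURRENT_TEST environment variable while a test is running, so a membership check on the environment is enough to detect it. The sketch below wraps that check in a function that accepts any mapping, which makes it testable without launching pytest. running_under_pytest and trainer_overrides are hypothetical helper names, not part of TorchGeo or Lightning.

```python
import os
from typing import Mapping, Optional


def running_under_pytest(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when pytest has exported PYTEST_CURRENT_TEST."""
    # Fall back to the real process environment when no mapping is given.
    if env is None:
        env = os.environ
    return "PYTEST_CURRENT_TEST" in env


def trainer_overrides(in_tests: bool) -> dict:
    """Hypothetical helper: Trainer keyword arguments that depend on the flag."""
    # fast_dev_run=True makes Lightning run a single batch of train/val/test.
    return {"max_epochs": 10, "fast_dev_run": in_tests}


print(running_under_pytest({"PYTEST_CURRENT_TEST": "test_nb.py::test_x (call)"}))  # True
print(running_under_pytest({}))  # False
```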
We will use TorchGeo’s datamodules, which build on PyTorch Lightning’s LightningDataModule, to organize the dataloader setup.
root = os.path.join(tempfile.gettempdir(), "eurosat")
datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4, download=True)
Available pretrained weights are listed on the model documentation page. While some weights only accept RGB input, others have been pretrained on Sentinel-2 imagery with 13 input channels and can therefore be useful for transfer learning tasks involving Sentinel-2 data.
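When a checkpoint only accepts RGB input, a 13-band Sentinel-2 sample must first be reduced to its red, green and blue bands (B04, B03, B02). Weights pretrained on all 13 channels consume every band directly. The plain-Python sketch below shows that band selection for a single pixel. It assumes the standard Sentinel-2 band order listed in SENTINEL2_BANDS. A given dataset class may spell band names differently (for instance B8A vs. B08A), so check them against the dataset before relying on these indices. select_bands is a hypothetical helper.

```python
from typing import List, Sequence

# Assumed order of the 13 Sentinel-2 bands in a multispectral sample.
SENTINEL2_BANDS = (
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B8A", "B09", "B10", "B11", "B12",
)
# Red, green and blue correspond to B04, B03 and B02.
RGB_BANDS = ("B04", "B03", "B02")


def select_bands(
    pixel: Sequence[float],
    wanted: Sequence[str],
    order: Sequence[str] = SENTINEL2_BANDS,
) -> List[float]:
    """Pick the requested bands, in the requested order, from one pixel."""
    if len(pixel) != len(order):
        raise ValueError(f"expected {len(order)} bands, got {len(pixel)}")
    index = {name: i for i, name in enumerate(order)}
    return [pixel[index[name]] for name in wanted]


# A fake pixel whose value equals its band index makes the mapping visible.
print(select_bands(list(range(13)), RGB_BANDS))  # [3, 2, 1]
```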
To access these weights you can do the following:
weights = ResNet50_Weights.SENTINEL2_ALL_MOCO
This set of weights is a torchvision WeightsEnum and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing the model with them. Given that EuroSAT is a classification dataset, we can use a ClassificationTask object, which holds the model, the optimizer, and the training logic.
task = ClassificationTask(
    model="resnet50",
    loss="ce",
    weights=weights,
    in_channels=13,
    num_classes=10,
    learning_rate=0.001,
    learning_rate_schedule_patience=5,
)
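The learning_rate_schedule_patience argument controls a reduce-on-plateau learning-rate schedule driven by the validation loss. As far as we understand, TorchGeo versions that accept this argument pass it as the patience of PyTorch’s ReduceLROnPlateau. The class below is a simplified, framework-free sketch of that rule. It ignores the real scheduler’s threshold, cooldown and min_lr options and uses its default factor of 0.1, so that you can see when the learning rate actually drops. PlateauScheduler is a hypothetical name.

```python
class PlateauScheduler:
    """Simplified reduce-on-plateau: cut the LR after too many bad epochs."""

    def __init__(self, lr: float, patience: int = 5, factor: float = 0.1):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss: float) -> float:
        # Any strict decrease counts as an improvement in this sketch.
        if val_loss < self.best:
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only once the bad-epoch count exceeds the patience.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauScheduler(lr=0.001, patience=5)
for epoch, loss in enumerate([1.0, 0.9] + [0.95] * 6, start=1):
    print(epoch, sched.step(loss))
```

With patience=5, the learning rate is cut only on the sixth consecutive epoch without improvement (epoch 8 above), because the comparison is strictly greater than the patience, matching PyTorch’s num_bad_epochs > patience check.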
If you do not want to use the ClassificationTask functionality for your experiments, you can also create a timm model and load pretrained weights from TorchGeo into it as follows:
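# Build the model with the number of input channels the weights were pretrained on.
in_chans = weights.meta["in_chans"]
model = timm.create_model("resnet50", in_chans=in_chans, num_classes=10)
# strict=False tolerates parameter names present in only one of model and checkpoint.
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)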
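load_state_dict(..., strict=False) matches parameters by name. Keys that exist only in the model (missing_keys) or only in the checkpoint (unexpected_keys) are tolerated, and the real call returns both lists so you can inspect them. It does not relax shape checks, however. A checkpoint whose first convolution expects 13 channels still fails to load into a 3-channel model, which is why the example reads in_chans from weights.meta. The sketch below reproduces that bookkeeping on plain dictionaries of shapes. check_partial_load is a hypothetical helper, and the layer names and shapes in the usage are illustrative assumptions, not the contents of the actual TorchGeo checkpoint.

```python
from typing import Dict, List, NamedTuple, Tuple

Shape = Tuple[int, ...]


class LoadResult(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_partial_load(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> LoadResult:
    """Mimic the key bookkeeping of load_state_dict(..., strict=False)."""
    # Names present on both sides must agree in shape, even when strict=False.
    shared = model_shapes.keys() & ckpt_shapes.keys()
    mismatched = sorted(k for k in shared if model_shapes[k] != ckpt_shapes[k])
    if mismatched:
        raise RuntimeError(f"size mismatch for {mismatched}")
    # Names present on only one side are reported instead of raising.
    missing = sorted(model_shapes.keys() - ckpt_shapes.keys())
    unexpected = sorted(ckpt_shapes.keys() - model_shapes.keys())
    return LoadResult(missing, unexpected)


# Illustrative shapes: a 13-channel stem plus a freshly created 10-class head.
example_model = {"conv1.weight": (64, 13, 7, 7), "fc.weight": (10, 2048), "fc.bias": (10,)}
example_ckpt = {"conv1.weight": (64, 13, 7, 7)}
print(check_partial_load(example_model, example_ckpt))
# LoadResult(missing_keys=['fc.bias', 'fc.weight'], unexpected_keys=[])
```

Any layer listed in missing_keys keeps its random initialization, so printing the return value of the real load_state_dict call is a quick way to confirm which parts of the network were actually pretrained.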
To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning’s Trainer. For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out this tutorial.
experiment_dir = os.path.join(tempfile.gettempdir(), "eurosat_results")

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath=experiment_dir, save_top_k=1, save_last=True
)
early_stopping_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10)
csv_logger = CSVLogger(save_dir=experiment_dir, name="pretrained_weights_logs")
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests,
)
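The EarlyStopping callback ends training once the monitored val_loss has failed to improve by more than min_delta for patience consecutive validation checks. The function below is a simplified, framework-free sketch of that rule in "min" mode; the real callback has more options. early_stop_epoch is a hypothetical name. Running it on ten epochs also reveals an interaction in the configuration above. Assuming one validation check per epoch, patience=10 with max_epochs=10 means early stopping can never fire within this run, because at most nine non-improving epochs can follow the first one.

```python
from typing import Iterable, Optional


def early_stop_epoch(
    val_losses: Iterable[float], patience: int = 10, min_delta: float = 0.0
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        # An improvement must beat the best loss by more than min_delta.
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


print(early_stop_epoch([0.5] + [0.6] * 9))   # None: ten epochs, never stops
print(early_stop_epoch([0.5] + [0.6] * 10))  # 11: would need an eleventh epoch
```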