Open in Colab Open on Planetary Computer

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

PyTorch Lightning trainers

In this tutorial, we demonstrate how to use TorchGeo trainers to train and test a model. Specifically, we use the Tropical Cyclone dataset and train models to predict cyclone wind speed given imagery of the cyclone.

It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.

Setup

First, we install TorchGeo.

[1]:
%pip install torchgeo[datasets]

Imports

Next, we import TorchGeo and any other libraries we need.

[2]:
%matplotlib inline
import os
import csv
import tempfile

import numpy as np
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from torchgeo.datamodules import TropicalCycloneDataModule
from torchgeo.trainers import RegressionTask
[3]:
# Set a flag that records whether the notebook is currently being run by pytest.
# If it is, we skip the expensive training.
in_tests = "PYTEST_CURRENT_TEST" in os.environ

Lightning modules

Our trainers use PyTorch Lightning to organize both the training code and the dataloader setup code. This makes it easy to create and share reproducible experiments and results.

First, we’ll create a TropicalCycloneDataModule object, which is simply a wrapper around the TropicalCyclone dataset. This object (1) ensures that the data is downloaded*, (2) sets up PyTorch DataLoader objects for the train, validation, and test splits, and (3) ensures that data from the same cyclone is not shared between the training and validation sets, so that you can properly evaluate the generalization performance of your model.

*To automatically download the dataset, you need an API key from the Radiant Earth MLHub. This is completely free, and will give you access to a growing catalog of ML-ready remote sensing datasets.

[4]:
# Set this to your API key (available for free at https://mlhub.earth/)
MLHUB_API_KEY = os.environ["MLHUB_API_KEY"]
[5]:
data_dir = os.path.join(tempfile.gettempdir(), "cyclone_data")

datamodule = TropicalCycloneDataModule(
    root=data_dir, seed=1337, batch_size=64, num_workers=6, api_key=MLHUB_API_KEY
)
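To make the third point above concrete, here is a minimal sketch of a group-aware split written with the standard library only. It is not TorchGeo's implementation: the function name `split_by_group`, the `val_fraction` parameter, and the toy storm identifiers are all hypothetical. The idea it illustrates is that whole storms, rather than individual images, are assigned to a split, so no storm contributes images to both training and validation.

```python
import random
from collections import defaultdict


def split_by_group(group_ids, val_fraction=0.2, seed=1337):
    """Split sample indices so that no group appears in both splits.

    `group_ids[i]` is the (hypothetical) storm identifier of sample `i`.
    """
    # Collect the sample indices belonging to each storm.
    samples_per_group = defaultdict(list)
    for index, group in enumerate(group_ids):
        samples_per_group[group].append(index)

    # Shuffle the storms, not the individual images, with a fixed seed.
    groups = sorted(samples_per_group)
    random.Random(seed).shuffle(groups)

    # Hold out a fraction of the storms (at least one) for validation.
    num_val_groups = max(1, round(len(groups) * val_fraction))
    val_groups = set(groups[:num_val_groups])

    train_indices, val_indices = [], []
    for group, indices in samples_per_group.items():
        target = val_indices if group in val_groups else train_indices
        target.extend(indices)
    return sorted(train_indices), sorted(val_indices)


# Toy example: ten images drawn from five storms.
storm_ids = ["abc", "abc", "def", "def", "def", "ghi", "jkl", "jkl", "mno", "mno"]
train_idx, val_idx = split_by_group(storm_ids)
train_storms = {storm_ids[i] for i in train_idx}
val_storms = {storm_ids[i] for i in val_idx}
print("train storms:", sorted(train_storms))
print("val storms:", sorted(val_storms))
```

A purely random split over images would place consecutive frames of one storm on both sides of the split, which tends to make validation scores look better than the model's performance on unseen storms.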

Next, we create a RegressionTask object that holds the model object, optimizer object, and training logic.

[6]:
task = RegressionTask(
    model="resnet18",
    pretrained=True,
    learning_rate=0.1,
    learning_rate_schedule_patience=5,
)
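The `learning_rate_schedule_patience` argument suggests a reduce-on-plateau style schedule, in the spirit of PyTorch's `ReduceLROnPlateau`. We assume that here without having confirmed it against the RegressionTask source. The sketch below is a simplified pure-Python version of that rule, not the library's code: the function name `plateau_schedule`, the reduction `factor` of 0.1, and the loss values are illustrative assumptions.

```python
def plateau_schedule(val_losses, initial_lr=0.1, patience=5, factor=0.1):
    """Return the learning rate in effect after each epoch.

    Simplified reduce-on-plateau rule: once the validation loss has failed to
    improve for more than `patience` consecutive epochs, multiply the learning
    rate by `factor` and start counting again.
    """
    lr = initial_lr
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical validation losses: two improving epochs, then a long plateau.
losses = [300.0, 250.0] + [250.0] * 7
lrs = plateau_schedule(losses)
print(lrs)
```

Under these assumptions, the learning rate stays at its initial value while the loss improves and drops by a factor of ten only after six consecutive epochs without improvement.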

Training

Now that we have the Lightning modules set up, we can use a PyTorch Lightning Trainer to run the training and evaluation loops. There are many useful pieces of configuration that can be set in the Trainer: below we set up model checkpointing based on the validation loss, early stopping based on the validation loss, and a CSV-based logger. We encourage you to see the PyTorch Lightning docs for other options that can be set here, e.g. TensorBoard logging, automatically selecting your optimizer’s learning rate, and easy multi-GPU training.

[7]:
experiment_dir = os.path.join(tempfile.gettempdir(), "cyclone_results")

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath=experiment_dir, save_top_k=1, save_last=True
)

early_stopping_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10)

csv_logger = CSVLogger(save_dir=experiment_dir, name="tutorial_logs")
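To illustrate what `patience` and `min_delta` mean for early stopping, the following is a simplified pure-Python sketch of a patience-based stopping rule for a metric that should decrease. It is not Lightning's implementation: the function name `stopping_epoch` and the loss values are hypothetical, and it assumes the monitored value is checked once per epoch.

```python
def stopping_epoch(val_losses, patience=10, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        # An epoch counts as an improvement only if it beats the best loss
        # seen so far by more than `min_delta`.
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical losses: improvement for three epochs, then no progress.
losses = [300.0, 260.0, 240.0] + [245.0] * 12
print(stopping_epoch(losses, patience=10))
print(stopping_epoch(losses[:10], patience=10))
```

In this sketch, a run of at most ten epochs can never accumulate ten epochs without improvement, because the first epoch always counts as an improvement. With the settings used in this tutorial, early stopping therefore matters mainly once you raise the epoch limit.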

For tutorial purposes we deliberately lower the maximum number of training epochs.

[8]:
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests,
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

When we first call .fit(...), the dataset will be downloaded and checksummed (if this hasn’t been done already). This can take 5–10 minutes. After this, the training process will kick off, and results will be saved to a CSV file.

[9]:
trainer.fit(model=task, datamodule=datamodule)
Files already downloaded and verified
Files already downloaded and verified
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.708    Total estimated model params size (MB)
/anaconda/envs/torchgeo/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /tmp/pip-req-build-19kunu9c/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

We load the log files and plot the training RMSE over batches, and the validation RMSE over epochs. We can see that our model is just starting to converge, and would probably benefit from additional training time and a lower initial learning rate.

[10]:
if not in_tests:
    train_steps = []
    train_rmse = []

    val_steps = []
    val_rmse = []
    with open(
        os.path.join(experiment_dir, "tutorial_logs", "version_0", "metrics.csv"), "r"
    ) as f:
        csv_reader = csv.DictReader(f, delimiter=",")
        for i, row in enumerate(csv_reader):
            try:
                train_rmse.append(float(row["train_RMSE"]))
                train_steps.append(i)
            except ValueError:  # Ignore rows where train RMSE is empty
                pass

            try:
                val_rmse.append(float(row["val_RMSE"]))
                val_steps.append(i)
            except ValueError:  # Ignore rows where val RMSE is empty
                pass
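The parsing loop above has to cope with empty cells, because each row of the metrics file carries only the values logged at that point. The sketch below reproduces that situation on a small in-memory example. The numbers are made up and the column layout is an assumption modeled on the keys used above; it is not a copy of a real log.

```python
import csv
import io

# Hypothetical excerpt in the style of the metrics file: each row holds only
# the metrics logged at that point, and the other columns are left empty.
sample = io.StringIO(
    "epoch,step,train_RMSE,val_RMSE\n"
    "0,49,21.5,\n"
    "0,99,19.0,\n"
    "0,99,,17.25\n"
)

train_rmse_demo, val_rmse_demo = [], []
for row in csv.DictReader(sample):
    # Empty cells are read as "", and float("") raises ValueError, which is
    # why the loop above wraps each conversion in try/except.
    if row["train_RMSE"]:
        train_rmse_demo.append(float(row["train_RMSE"]))
    if row["val_RMSE"]:
        val_rmse_demo.append(float(row["val_RMSE"]))

print(train_rmse_demo, val_rmse_demo)
```

Checking for an empty string is an alternative to catching `ValueError`; either approach skips rows that do not contain the metric of interest.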
[11]:
if not in_tests:
    plt.figure()
    plt.plot(train_steps, train_rmse, label="Train RMSE")
    plt.plot(val_steps, val_rmse, label="Validation RMSE")
    plt.legend(fontsize=15)
    plt.xlabel("Batches", fontsize=15)
    plt.ylabel("RMSE", fontsize=15)
    plt.show()
    plt.close()
[Figure: train and validation RMSE plotted against batches]

Finally, after the model has been trained, we can easily evaluate it on the test set.

[12]:
trainer.test(model=task, datamodule=datamodule)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 227.86709594726562, 'test_rmse': 13.240558624267578}
--------------------------------------------------------------------------------
[12]:
[{'test_loss': 227.86709594726562, 'test_rmse': 13.240558624267578}]