Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Lightning Trainers
In this tutorial, we demonstrate how to use TorchGeo trainers to train and test a model. Specifically, we use the Tropical Cyclone dataset and train models to predict cyclone wind speed given imagery of the cyclone.
It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.
Setup
First, we install TorchGeo.
[1]:
%pip install torchgeo[datasets]
Imports
Next, we import TorchGeo and any other libraries we need.
[2]:
%matplotlib inline
import os
import csv
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from torchgeo.datamodules import TropicalCycloneDataModule
from torchgeo.trainers import RegressionTask
[3]:
# Flag whether the notebook is currently being run by pytest. If it is, we
# skip the expensive training.
in_tests = "PYTEST_CURRENT_TEST" in os.environ
Lightning modules
Our trainers use Lightning to organize both the training code and the dataloader setup code. This makes it easy to create and share reproducible experiments and results.
First we’ll create a TropicalCycloneDataModule object, which is simply a wrapper around the TropicalCyclone dataset. This object (1) ensures that the data is downloaded*, (2) sets up PyTorch DataLoader objects for the train, validation, and test splits, and (3) ensures that data from the same cyclone is not shared between the training and validation sets, so that you can properly evaluate the generalization performance of your model.
*To automatically download the dataset, you need an API key from the Radiant Earth MLHub. This is completely free, and will give you access to a growing catalog of ML-ready remote sensing datasets.
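To illustrate point (3), here is a minimal sketch of a group-wise split, where whole cyclones (rather than individual images) are assigned to either the training or the validation set. This is not TorchGeo's implementation; the function name `split_by_group`, the `val_fraction` parameter, and the storm identifiers are all hypothetical and chosen only for illustration.

```python
import random
from collections import defaultdict


def split_by_group(group_ids, val_fraction=0.2, seed=0):
    """Split sample indices so that no group appears in both subsets.

    group_ids: one (hypothetical) storm identifier per sample.
    Returns (train_indices, val_indices).
    """
    # Collect the sample indices belonging to each group.
    groups = defaultdict(list)
    for index, group in enumerate(group_ids):
        groups[group].append(index)

    # Shuffle the unique groups, not the individual samples.
    unique_groups = sorted(groups)
    random.Random(seed).shuffle(unique_groups)

    # Hold out roughly val_fraction of the groups (at least one).
    num_val = max(1, round(len(unique_groups) * val_fraction))
    val_groups = set(unique_groups[:num_val])

    train_indices, val_indices = [], []
    for group, indices in groups.items():
        target = val_indices if group in val_groups else train_indices
        target.extend(indices)
    return sorted(train_indices), sorted(val_indices)


# Made-up storm identifiers, one per image.
storm_ids = ["abc", "abc", "def", "def", "def", "ghi", "jkl", "jkl", "mno", "mno"]
train_idx, val_idx = split_by_group(storm_ids, val_fraction=0.2)
train_storms = {storm_ids[i] for i in train_idx}
val_storms = {storm_ids[i] for i in val_idx}
print(train_storms & val_storms)  # empty set: no storm is in both subsets
```

Splitting at the level of storms matters because consecutive images of the same cyclone are highly correlated. A random per-image split would let near-duplicates of training images leak into the validation set and inflate the validation score.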
[4]:
# Set this to your API key (available for free at https://mlhub.earth/)
MLHUB_API_KEY = os.environ["MLHUB_API_KEY"]
[5]:
data_dir = os.path.join(tempfile.gettempdir(), "cyclone_data")
datamodule = TropicalCycloneDataModule(
    root=data_dir, batch_size=64, num_workers=6, download=True, api_key=MLHUB_API_KEY
)
Next, we create a RegressionTask object that holds the model object, optimizer object, and training logic.
[6]:
task = RegressionTask(
    model="resnet18",
    weights=True,
    in_channels=3,
    num_outputs=1,
    learning_rate=0.1,
    learning_rate_schedule_patience=5,
)
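The `learning_rate_schedule_patience` argument is not explained in this tutorial. Its name suggests a reduce-on-plateau learning rate schedule, in which the learning rate is lowered once the monitored loss has stopped improving for more than `patience` epochs. The sketch below simulates that idea in plain Python. It is an assumption about the trainer's behavior, not its actual code, and the function name `simulate_plateau_schedule` and the `factor` value are hypothetical.

```python
def simulate_plateau_schedule(val_losses, lr=0.1, patience=5, factor=0.1):
    """Return the learning rate in effect after each epoch.

    Simplified reduce-on-plateau rule (an assumption, not TorchGeo's code):
    multiply the learning rate by `factor` once the loss has failed to
    improve for more than `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0  # start counting again after a reduction
        history.append(lr)
    return history


# Made-up losses: one improvement, then six epochs without improvement.
lrs = simulate_plateau_schedule([1.0] + [2.0] * 6, lr=0.1, patience=5)
print(lrs)
```

With these made-up losses, the learning rate stays at its initial value for the first six epochs and is reduced on the seventh, after the sixth consecutive epoch without improvement.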
Training
Now that we have the Lightning modules set up, we can use a Lightning Trainer to run the training and evaluation loops. There are many useful pieces of configuration that can be set in the Trainer. Below, we set up model checkpointing based on the validation loss, early stopping based on the validation loss, and a CSV-based logger. We encourage you to see the Lightning docs for other options that can be set here, e.g. TensorBoard logging, automatically selecting your optimizer’s learning rate, and easy multi-GPU training.
[7]:
experiment_dir = os.path.join(tempfile.gettempdir(), "cyclone_results")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath=experiment_dir, save_top_k=1, save_last=True
)
early_stopping_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10)
csv_logger = CSVLogger(save_dir=experiment_dir, name="tutorial_logs")
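To make the `patience` and `min_delta` arguments more concrete, the following is a simplified sketch of the early stopping rule for a loss that should decrease. It is not Lightning's implementation; the function name `epochs_until_early_stop` and the loss values are hypothetical.

```python
def epochs_until_early_stop(val_losses, patience=10, min_delta=0.0):
    """Return the epoch index at which training would stop, or None.

    Simplified rule: an epoch counts as an improvement only if the loss
    drops below the best value seen so far by more than `min_delta`.
    Training stops after `patience` consecutive epochs without improvement.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses: improvement for three epochs, then a plateau.
losses = [5.0, 4.0, 3.5, 3.6, 3.7, 3.8]
print(epochs_until_early_stop(losses, patience=3))   # 5
print(epochs_until_early_stop(losses, patience=10))  # None
```

Under this simplified rule, and assuming one validation pass per epoch, a patience of 10 cannot trigger within a run capped at 10 epochs, because the first epoch always counts as an improvement. Early stopping becomes relevant once the maximum number of epochs is raised.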
For tutorial purposes we deliberately lower the maximum number of training epochs.
[8]:
trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests,
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
When we first call .fit(...), the dataset will be downloaded and checksummed (if it hasn’t been already). This can take 5–10 minutes. After this, the training process will kick off, and results will be saved to a CSV file.
[9]:
trainer.fit(model=task, datamodule=datamodule)
Files already downloaded and verified
Files already downloaded and verified
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M Trainable params
0 Non-trainable params
11.2 M Total params
44.708 Total estimated model params size (MB)
/anaconda/envs/torchgeo/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /tmp/pip-req-build-19kunu9c/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
We load the log files and plot the training RMSE over batches, and the validation RMSE over epochs. We can see that our model is just starting to converge, and would probably benefit from additional training time and a lower initial learning rate.
[10]:
if not in_tests:
    train_steps = []
    train_rmse = []

    val_steps = []
    val_rmse = []
    with open(
        os.path.join(experiment_dir, "tutorial_logs", "version_0", "metrics.csv"), "r"
    ) as f:
        csv_reader = csv.DictReader(f, delimiter=",")
        for i, row in enumerate(csv_reader):
            try:
                train_rmse.append(float(row["train_RMSE"]))
                train_steps.append(i)
            except ValueError:  # Ignore rows where train RMSE is empty
                pass

            try:
                val_rmse.append(float(row["val_RMSE"]))
                val_steps.append(i)
            except ValueError:  # Ignore rows where val RMSE is empty
                pass
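The try/except blocks above are needed because each row of the log holds only the metrics recorded at that moment, so the other columns are empty. The self-contained sketch below shows the same parsing on a small in-memory log. The numbers are made up for illustration, not real results, and the sketch assumes the file also has a `step` column that can be used for the x-axis instead of the row index.

```python
import csv
import io

# Hypothetical log contents, made up for illustration (not real results).
log_text = """epoch,step,train_RMSE,val_RMSE
0,49,21.5,
0,99,18.2,
0,99,,17.9
1,149,16.4,
1,199,,16.8
"""

train_points, val_points = [], []
for row in csv.DictReader(io.StringIO(log_text)):
    step = int(row["step"])
    # An empty string means the metric was not logged in this row.
    if row["train_RMSE"]:
        train_points.append((step, float(row["train_RMSE"])))
    if row["val_RMSE"]:
        val_points.append((step, float(row["val_RMSE"])))

print(train_points)  # [(49, 21.5), (99, 18.2), (149, 16.4)]
print(val_points)    # [(99, 17.9), (199, 16.8)]
```

Checking for an empty string avoids using exceptions for control flow, and plotting against the logged step keeps the training and validation curves aligned on the same axis.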
[11]:
if not in_tests:
    plt.figure()
    plt.plot(train_steps, train_rmse, label="Train RMSE")
    plt.plot(val_steps, val_rmse, label="Validation RMSE")
    plt.legend(fontsize=15)
    plt.xlabel("Batches", fontsize=15)
    plt.ylabel("RMSE", fontsize=15)
    plt.show()
    plt.close()

Finally, after the model has been trained, we can easily evaluate it on the test set.
[12]:
trainer.test(model=task, datamodule=datamodule)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 227.86709594726562, 'test_rmse': 13.240558624267578}
--------------------------------------------------------------------------------
[12]:
[{'test_loss': 227.86709594726562, 'test_rmse': 13.240558624267578}]