[ ]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

Transforms

Written by: Isaac A. Corley

In this tutorial, we demonstrate how to use TorchGeo’s data augmentation transforms and provide examples of applying them to multispectral imagery in your experiments.

It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.

Setup

Install TorchGeo

[ ]:
%pip install torchgeo

Imports

[ ]:
import os
import tempfile

import kornia.augmentation as K
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT100
from torchgeo.transforms import indices

Custom Transforms

Here we create a custom transform to show how you can chain your own operations with TorchGeo and Kornia transforms/augmentations. Note that TorchGeo transforms take a dict of Tensors as input: data is specified by keys such as “image”, “mask”, and “label”, a convention followed across TorchGeo datasets.

[ ]:
class MinMaxNormalize(K.IntensityAugmentationBase2D):
    """Normalize channels to the range [0, 1] using min/max values."""

    def __init__(self, mins: Tensor, maxs: Tensor) -> None:
        super().__init__(p=1)
        self.flags = {'mins': mins.view(1, -1, 1, 1), 'maxs': maxs.view(1, -1, 1, 1)}

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, int],
        transform: Tensor | None = None,
    ) -> Tensor:
        return (input - flags['mins']) / (flags['maxs'] - flags['mins'] + 1e-10)
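
As a quick sanity check, the custom transform can be applied to any batch-shaped tensor. The values below are purely illustrative and unrelated to the dataset statistics defined in the next section.

[ ]:
# Illustrative values only (not the dataset statistics defined below)
demo_mins = torch.tensor([0.0, 0.0])
demo_maxs = torch.tensor([1000.0, 2000.0])
demo_transform = MinMaxNormalize(demo_mins, demo_maxs)
demo_x = torch.rand(4, 2, 8, 8) * 1000  # (batch, channels, height, width)
demo_out = demo_transform(demo_x)
print(demo_out.min(), demo_out.max())  # each channel rescaled by (x - min) / (max - min)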

Dataset Bands and Statistics

Below we have min/max values calculated across the dataset per band. The values were clipped to the [0, 98] percentile interval to stretch the band values and keep outliers from skewing the band histograms.

[ ]:
mins = torch.tensor(
    [
        1013.0,
        676.0,
        448.0,
        247.0,
        269.0,
        253.0,
        243.0,
        189.0,
        61.0,
        4.0,
        33.0,
        11.0,
        186.0,
    ]
)
maxs = torch.tensor(
    [
        2309.0,
        4543.05,
        4720.2,
        5293.05,
        3902.05,
        4473.0,
        5447.0,
        5948.05,
        1829.0,
        23.0,
        4894.05,
        4076.05,
        5846.0,
    ]
)
bands = {
    'B01': 'Coastal Aerosol',
    'B02': 'Blue',
    'B03': 'Green',
    'B04': 'Red',
    'B05': 'Vegetation Red Edge 1',
    'B06': 'Vegetation Red Edge 2',
    'B07': 'Vegetation Red Edge 3',
    'B08': 'NIR 1',
    'B09': 'Water Vapour',
    'B10': 'SWIR 1',
    'B11': 'SWIR 2',
    'B12': 'SWIR 3',
    'B8A': 'NIR 2',
}

The following variables can be used to control the dataloader.

[ ]:
batch_size = 4
num_workers = 2

Load the EuroSAT MS dataset and dataloader

We will use the EuroSAT dataset throughout this tutorial; specifically, EuroSAT100, a subset containing only 100 images.

[ ]:
root = os.path.join(tempfile.gettempdir(), 'eurosat100')
dataset = EuroSAT100(root, download=True)
dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
dataloader = iter(dataloader)
print(f'Number of images in dataset: {len(dataset)}')
print(f'Dataset Classes: {dataset.classes}')
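
As an aside, per-band statistics like the mins and maxs defined earlier could be estimated from a dataset such as this one. The following is only a rough sketch, not the exact procedure used to produce those numbers.

[ ]:
# Rough sketch: estimate per-band statistics from the dataset (illustrative only,
# not the procedure used for the `mins`/`maxs` above).
all_images = torch.stack([dataset[i]['image'] for i in range(len(dataset))])  # (N, 13, 64, 64)
per_band = all_images.permute(1, 0, 2, 3).reshape(all_images.shape[1], -1)  # (13, N * 64 * 64)
est_mins = per_band.min(dim=1).values  # per-band minimum
est_maxs = torch.quantile(per_band, 0.98, dim=1)  # clip the upper tail at the 98th percentile
print(est_mins)
print(est_maxs)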

Load a sample and batch of images and labels

Here we test our dataset by loading a single image and label. Note how the image has shape (13, 64, 64): a 64 x 64 pixel image with 13 multispectral bands.

[ ]:
sample = dataset[0]
x, y = sample['image'], sample['label']
print(x.shape, x.dtype, x.min(), x.max())
print(y, dataset.classes[y])

Here we test our dataloader by loading a single batch of images and labels. Note how the images now have shape (4, 13, 64, 64), with 4 samples due to our batch_size.

[ ]:
batch = next(dataloader)
x, y = batch['image'], batch['label']
print(x.shape, x.dtype, x.min(), x.max())
print(y, [dataset.classes[i] for i in y])

Transforms Usage

torchgeo.transforms work seamlessly with both singular samples and batches of data. They can be applied within datasets or externally and combined with other transforms using nn.Sequential. Built for multispectral imagery, they are fully compatible with torchvision.transforms and kornia.augmentation.

[ ]:
transform = MinMaxNormalize(mins, maxs)
print(x.shape)
x = transform(x)
print(x.dtype, x.min(), x.max())
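
The same transform also works on a single, unbatched sample. A minimal check, reusing the sample loaded earlier; note that Kornia may return the result with a batch dimension prepended.

[ ]:
# Apply the same normalization to a single (13, 64, 64) sample.
# Kornia adds a batch dimension internally, so the result is typically (1, 13, 64, 64).
x_single = transform(sample['image'])
print(x_single.shape, x_single.min(), x_single.max())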

Appending Indices

torchgeo.transforms support appending indices to a specified channel dimension.

For detailed usage of all available transforms, refer to the transforms documentation.

The following example shows how indices can be computed on batches of images and appended as an additional band to the specified channel dimension. Notice how the number of channels increases from 13 -> 14.

[ ]:
transform = indices.AppendNDVI(index_nir=7, index_red=3)
batch = next(dataloader)
x = batch['image']
print(x.shape)
x = transform(x)
print(x.shape)
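
Under the hood, AppendNDVI computes the normalized difference (NIR - Red) / (NIR + Red) from the bands at index_nir and index_red and concatenates the result as an extra channel. Below is a minimal sketch of an equivalent manual computation, for intuition only; the exact handling of zero denominators in TorchGeo may differ.

[ ]:
# Minimal sketch of what AppendNDVI amounts to (for intuition only).
raw = batch['image']  # original 13-band batch
nir, red = raw[:, 7:8], raw[:, 3:4]  # B08 (NIR) and B04 (Red)
ndvi = (nir - red) / (nir + red + 1e-10)  # small epsilon guards against division by zero
print(torch.cat([raw, ndvi], dim=1).shape)  # channels: 13 -> 14, matching the output above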

This makes it incredibly easy to add indices as additional features during training by chaining multiple Appends together.

[ ]:
transforms = nn.Sequential(
    MinMaxNormalize(mins, maxs),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=3, index_swir=11),
    indices.AppendNDVI(index_nir=7, index_red=3),
    indices.AppendNDWI(index_green=2, index_nir=7),
)

batch = next(dataloader)
x = batch['image']
print(x.shape)
x = transforms(x)
print(x.shape)

It’s even possible to chain indices together with Kornia augmentations into a single callable to use during training.

When using Kornia with a dictionary input, you must explicitly set data_keys=None during the creation of the augmentation pipeline.

[ ]:
transforms = K.AugmentationSequential(
    MinMaxNormalize(mins, maxs),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=3, index_swir=11),
    indices.AppendNDVI(index_nir=7, index_red=3),
    indices.AppendNDWI(index_green=2, index_nir=7),
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    data_keys=None,
)

batch = next(dataloader)
print(batch['image'].shape)
batch = transforms(batch)
print(batch['image'].shape)

All of our transforms are nn.Modules, so both the transforms and the data can be pushed to the GPU, yielding significant speedups for large-scale operations.

[ ]:
!nvidia-smi
[ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

transforms = K.AugmentationSequential(
    MinMaxNormalize(mins, maxs),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=3, index_swir=11),
    indices.AppendNDVI(index_nir=7, index_red=3),
    indices.AppendNDWI(index_green=2, index_nir=7),
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomAffine(degrees=(0, 90), p=0.25),
    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
    data_keys=None,
)

transforms_gpu = K.AugmentationSequential(
    MinMaxNormalize(mins.to(device), maxs.to(device)),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=3, index_swir=11),
    indices.AppendNDVI(index_nir=7, index_red=3),
    indices.AppendNDWI(index_green=2, index_nir=7),
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomAffine(degrees=(0, 90), p=0.25),
    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
    data_keys=None,
).to(device)


def get_batch_cpu():
    return dict(image=torch.randn(64, 13, 512, 512).to('cpu'))


def get_batch_gpu():
    return dict(image=torch.randn(64, 13, 512, 512).to(device))
[ ]:
%%timeit -n 1 -r 5
_ = transforms(get_batch_cpu())
[ ]:
%%timeit -n 1 -r 5
_ = transforms_gpu(get_batch_gpu())

Visualize Images and Labels

This is a Google Colab browser for the EuroSAT dataset. Adjust the slider to visualize images in the dataset.

[ ]:
transforms = K.AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=None)
dataset = EuroSAT100(root, transforms=transforms)
[ ]:
# @title EuroSAT Multispectral (MS) Browser  { run: "auto", vertical-output: true }
idx = 21  # @param {type:"slider", min:0, max:59, step:1}
sample = dataset[idx]
rgb = sample['image'][0, 1:4]  # visible bands B02-B04 (Blue, Green, Red)
image = T.ToPILImage()(rgb)
print(f'Class Label: {dataset.classes[sample["label"]]}')
image.resize((256, 256), resample=Image.BILINEAR)

Additional Reading

To learn more about preprocessing and data augmentation transforms, the documentation for torchvision.transforms and kornia.augmentation (both used in this tutorial) may be helpful.
