Open in Colab

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Transforms

In this tutorial, we demonstrate how to use TorchGeo’s data augmentation transforms and provide examples of how to utilize them in your experiments with multispectral imagery.

It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.

Setup

Install TorchGeo

[ ]:
%pip install torchgeo

Imports

[16]:
from typing import Dict, List

import kornia.augmentation as K
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT
from torchgeo.transforms import AugmentationSequential, indices

Custom Transforms

Here we create a transform to show how you can chain custom operations with TorchGeo and Kornia transforms/augmentations. Note that our transform takes a Dict of Tensors as input. We identify each piece of data by its key (“image”, “mask”, “label”, etc.) and follow this convention across TorchGeo datasets.

[2]:
class MinMaxNormalize(nn.Module):
    """Normalize channels to the range [0, 1] using min/max values."""

    def __init__(self, min: List[float], max: List[float]) -> None:
        super().__init__()
        # Shape (C, 1, 1) so the statistics broadcast over height and width.
        # Buffers follow the module when it is moved with .to(device).
        self.register_buffer("min", torch.tensor(min)[:, None, None])
        self.register_buffer("max", torch.tensor(max)[:, None, None])
        self.register_buffer("denominator", self.max - self.min)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        x = inputs["image"]

        # Batch of shape (N, C, H, W)
        if x.ndim == 4:
            x = (x - self.min[None, ...]) / self.denominator[None, ...]
        # Single sample of shape (C, H, W)
        else:
            x = (x - self.min) / self.denominator

        inputs["image"] = x.clamp(0, 1)
        return inputs
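
To make the broadcasting behind this normalization concrete, here is a minimal sketch in plain PyTorch. The function name `min_max_normalize` and the toy statistics are illustrative assumptions and are not part of TorchGeo. It shows that per-channel statistics of shape (C, 1, 1) broadcast against both a single sample and a batch.

```python
import torch


def min_max_normalize(x, mins, maxs):
    """Scale each channel of x to [0, 1]; x is (C, H, W) or (N, C, H, W)."""
    mins = torch.as_tensor(mins, dtype=torch.float32)[:, None, None]  # (C, 1, 1)
    maxs = torch.as_tensor(maxs, dtype=torch.float32)[:, None, None]  # (C, 1, 1)
    # Trailing-dimension broadcasting aligns (C, 1, 1) with the last three axes
    return ((x - mins) / (maxs - mins)).clamp(0, 1)


# Toy two-band data with made-up statistics (hypothetical values)
toy_mins, toy_maxs = [0.0, 100.0], [2000.0, 2500.0]
sample = {"image": torch.randint(0, 3000, (2, 8, 8), dtype=torch.int32)}
batch = {"image": torch.randint(0, 3000, (4, 2, 8, 8), dtype=torch.int32)}

norm_sample = min_max_normalize(sample["image"], toy_mins, toy_maxs)
norm_batch = min_max_normalize(batch["image"], toy_mins, toy_maxs)
print(norm_sample.shape, norm_sample.dtype)
print(norm_batch.shape, norm_batch.dtype)
```

Integer inputs are promoted to floating point by the subtraction, which is why the normalized image comes back as float32.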

Dataset Bands and Statistics

Below are the per-band min/max values calculated across the dataset. The values were clipped to the [0, 98] percentile interval to stretch the band values and keep outliers from influencing the band histograms.

[3]:
mins = [1013.0, 676.0, 448.0, 247.0, 269.0, 253.0, 243.0, 189.0, 61.0, 4.0, 33.0, 11.0, 186.0]
maxs = [2309.0, 4543.05, 4720.2, 5293.05, 3902.05, 4473.0, 5447.0, 5948.05, 1829.0, 23.0, 4894.05, 4076.05, 5846.0]
bands = {
    "B1": "Coastal Aerosol",
    "B2": "Blue",
    "B3": "Green",
    "B4": "Red",
    "B5": "Vegetation Red Edge 1",
    "B6": "Vegetation Red Edge 2",
    "B7": "Vegetation Red Edge 3",
    "B8": "NIR 1",
    "B8A": "NIR 2",
    "B9": "Water Vapour",
    "B10": "SWIR 1",
    "B11": "SWIR 2",
    "B12": "SWIR 3"
}
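
Statistics like these could be derived along the following lines. This is a sketch under assumptions: it reads the [0, 98] interval as percentiles, and it uses a small synthetic image stack rather than the real EuroSAT data, so the numbers it produces are not the ones listed above.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in for a stack of images with shape (N, C, H, W)
images = torch.rand(16, 3, 8, 8) * 4000.0

# Rearrange so each row holds every pixel of one band: (C, N * H * W)
per_band = images.permute(1, 0, 2, 3).reshape(images.shape[1], -1)

band_mins = per_band.min(dim=1).values             # 0th percentile per band
band_maxs = torch.quantile(per_band, 0.98, dim=1)  # 98th percentile per band

print(band_mins.shape, band_maxs.shape)
```

Using the 98th percentile instead of the true maximum means a handful of very bright pixels cannot compress the rest of the band into a narrow range after normalization.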

Load the EuroSAT MS dataset and dataloader

Here we load the EuroSAT Multispectral (MS) dataset. The dataset contains 27,000 Sentinel-2 multispectral patches of 64x64 pixels, each labeled with one of 10 land cover classes.

[4]:
dataset = EuroSAT(download=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
dataloader = iter(dataloader)
print(f"Number of images in dataset: {len(dataset)}")
print(f"Dataset Classes: {dataset.classes}")
Number of images in dataset: 27000
Dataset Classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']

Load a sample and batch of images and labels

Here we test our dataset by loading a single image and label. Note that the image has shape (13, 64, 64): a 64x64 patch with 13 multispectral bands.

[16]:
sample = dataset[0]
x, y = sample["image"], sample["label"]
print(x.shape, x.dtype, x.min(), x.max())
print(y, dataset.classes[y])
torch.Size([13, 64, 64]) torch.int32 tensor(9, dtype=torch.int32) tensor(3490, dtype=torch.int32)
tensor(0) AnnualCrop

Here we test our dataloader by loading a single batch of images and labels. Note how the image is of shape (4, 13, 64, 64) containing 4 samples due to our batch_size.

[17]:
batch = next(dataloader)
x, y = batch["image"], batch["label"]
print(x.shape, x.dtype, x.min(), x.max())
print(y, [dataset.classes[i] for i in y])
torch.Size([4, 13, 64, 64]) torch.int32 tensor(6, dtype=torch.int32) tensor(4696, dtype=torch.int32)
tensor([6, 1, 8, 4]) ['PermanentCrop', 'Forest', 'River', 'Industrial']
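
The dictionary-of-tensors convention works with PyTorch's default collation, which stacks each key separately. The sketch below uses a hypothetical `ToyDictDataset` with random data in place of EuroSAT, so it runs without any download.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    """Hypothetical dataset that returns samples as dicts, like TorchGeo does."""

    def __len__(self):
        return 8

    def __getitem__(self, index):
        image = torch.randint(0, 5000, (13, 64, 64), dtype=torch.int32)
        label = torch.tensor(index % 10)
        return {"image": image, "label": label}


toy_loader = DataLoader(ToyDictDataset(), batch_size=4)
toy_batch = next(iter(toy_loader))

# The default collate function stacks the tensors stored under each key
print(toy_batch["image"].shape, toy_batch["label"].shape)
```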

Transforms Usage

Transforms can operate on both batches of samples and single samples. This allows them to be used inside the dataset itself or applied externally, chained together with other transform operations using nn.Sequential.

[19]:
transforms = MinMaxNormalize(mins, maxs)
print(batch["image"].shape)
batch = transforms(batch)
print(batch["image"].dtype, batch["image"].min(), batch["image"].max())
torch.Size([4, 13, 64, 64])
torch.float32 tensor(0.0079) tensor(0.9734)

Indices can also be computed on batches of images and appended as an additional band to the specified channel dimension. Notice how the number of channels increases from 13 -> 14.

[20]:
transform = indices.AppendNDVI(index_red=3, index_nir=7)
batch = next(dataloader)
print(batch["image"].shape)
batch = transform(batch)
print(batch["image"].shape)
torch.Size([4, 13, 64, 64])
torch.Size([4, 14, 64, 64])
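
For reference, NDVI is the normalized difference (NIR - Red) / (NIR + Red). The sketch below computes it by hand on random data and appends it as a 14th channel. It illustrates the formula only and is not TorchGeo's implementation; the small `eps` term that guards against division by zero is an assumption of this sketch.

```python
import torch

eps = 1e-8  # assumed guard against division by zero
toy = torch.rand(4, 13, 64, 64)  # random stand-in for a normalized batch

red = toy[:, 3]  # B4 sits at channel index 3
nir = toy[:, 7]  # B8 sits at channel index 7
ndvi = (nir - red) / (nir + red + eps)  # shape (N, H, W)

# Restore the channel axis and append the index as an extra band
with_ndvi = torch.cat([toy, ndvi.unsqueeze(1)], dim=1)
print(with_ndvi.shape)
```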

This makes it straightforward to add indices as additional features during training by chaining multiple Append transforms together.

[21]:
transforms = nn.Sequential(
    MinMaxNormalize(mins, maxs),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=2, index_swir=11),
    indices.AppendNDVI(index_red=3, index_nir=7),
    indices.AppendNDWI(index_green=2, index_nir=7),
)

batch = next(dataloader)
print(batch["image"].shape)
batch = transforms(batch)
print(batch["image"].shape)
torch.Size([4, 13, 64, 64])
torch.Size([4, 17, 64, 64])
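
Chaining works because every module accepts a sample dictionary and returns it, so nn.Sequential can pass the same dictionary from one step to the next. The sketch below shows the pattern with a hypothetical `AppendNormalizedDifference` module written for illustration; it is not a TorchGeo class.

```python
import torch
import torch.nn as nn


class AppendNormalizedDifference(nn.Module):
    """Hypothetical module: append (a - b) / (a + b) as a new channel."""

    def __init__(self, index_a, index_b, eps=1e-8):
        super().__init__()
        self.index_a = index_a
        self.index_b = index_b
        self.eps = eps  # assumed guard against division by zero

    def forward(self, sample):
        x = sample["image"]
        # The ellipsis lets this work for (C, H, W) and (N, C, H, W) inputs
        a = x[..., self.index_a, :, :]
        b = x[..., self.index_b, :, :]
        nd = (a - b) / (a + b + self.eps)
        sample["image"] = torch.cat([x, nd.unsqueeze(-3)], dim=-3)
        return sample


pipeline = nn.Sequential(
    AppendNormalizedDifference(index_a=7, index_b=3),  # NDVI-style: NIR vs. red
    AppendNormalizedDifference(index_a=2, index_b=7),  # NDWI-style: green vs. NIR
)
out = pipeline({"image": torch.rand(4, 13, 64, 64)})
print(out["image"].shape)
```

Each step adds one channel, so two chained modules take the batch from 13 to 15 channels.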

Indices can also be chained with augmentations from Kornia to form a single callable for use during training.

[34]:
augmentations = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    data_keys=["image"],
)
transforms = nn.Sequential(
    MinMaxNormalize(mins, maxs),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=2, index_swir=11),
    indices.AppendNDVI(index_red=3, index_nir=7),
    indices.AppendNDWI(index_green=2, index_nir=7),
    augmentations
)

batch = next(dataloader)
print(batch["image"].shape)
batch = transforms(batch)
print(batch["image"].shape)
torch.Size([4, 13, 64, 64])
torch.Size([4, 17, 64, 64])
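
As a rough picture of what a batched random flip does, the sketch below draws one coin flip per sample and mirrors only the selected images. The `RandomHorizontalFlipDict` class is a simplified stand-in written for illustration; it is not Kornia's implementation and handles only the "image" key.

```python
import torch
import torch.nn as nn


class RandomHorizontalFlipDict(nn.Module):
    """Hypothetical per-sample horizontal flip for dict batches."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, sample):
        x = sample["image"]  # expected shape (N, C, H, W)
        # One independent decision per sample in the batch
        flip = torch.rand(x.shape[0], device=x.device) < self.p
        flipped = torch.flip(x, dims=[-1])  # mirror along the width axis
        sample["image"] = torch.where(flip[:, None, None, None], flipped, x)
        return sample


toy_batch = {"image": torch.rand(4, 17, 64, 64)}
augmented = RandomHorizontalFlipDict(p=0.5)(toy_batch)
print(augmented["image"].shape)
```

Flips change pixel positions but not the number of channels, which matches the unchanged shape in the output above.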

All of our transforms are nn.Modules, so both the transforms and the data can be moved to the GPU, which can speed up large-scale operations.

[8]:
!nvidia-smi
Tue Sep 28 20:52:49 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8    27W / 149W |      3MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
[10]:
augmentations = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomAffine(degrees=(0, 90), p=0.25),
    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
    K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
    data_keys=["image"],
)
transforms = nn.Sequential(
    MinMaxNormalize(mins, maxs),
    indices.AppendNDBI(index_swir=11, index_nir=7),
    indices.AppendNDSI(index_green=2, index_swir=11),
    indices.AppendNDVI(index_red=3, index_nir=7),
    indices.AppendNDWI(index_green=2, index_nir=7),
    augmentations
)

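import copy

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_cpu = dict(image=torch.randn(64, 13, 512, 512).to("cpu"))
batch_gpu = dict(image=torch.randn(64, 13, 512, 512).to(device))

# nn.Module.to() moves a module in place, so copy it first to keep a CPU version
transforms_gpu = copy.deepcopy(transforms).to(device)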
[11]:
%%timeit -n 1 -r 1
_ = transforms(batch_cpu)
1 loop, best of 1: 12.3 s per loop
[12]:
%%timeit -n 1 -r 1
_ = transforms_gpu(batch_gpu)
1 loop, best of 1: 8.58 s per loop
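
CUDA kernels are launched asynchronously, so a manual wall-clock measurement is more trustworthy when the device is synchronized before the clock is read. The sketch below shows one way to do this; the helper name `time_call` and the toy workload are assumptions for illustration, and the code falls back to the CPU when no GPU is present.

```python
import time

import torch


def time_call(fn, *args, device="cpu"):
    """Return (result, seconds), synchronizing around the call on CUDA."""
    if device == "cuda":
        torch.cuda.synchronize()  # finish any pending GPU work first
    start = time.perf_counter()
    result = fn(*args)
    if device == "cuda":
        torch.cuda.synchronize()  # wait for the timed kernels to complete
    return result, time.perf_counter() - start


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 13, 64, 64, device=device)

# Toy workload standing in for a transform pipeline
_, seconds = time_call(lambda t: torch.flip(t, dims=[-1]).clamp(0, 1), x, device=device)
print(f"{device}: {seconds:.4f} s")
```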

Visualize Images and Labels

This is a Google Colab browser for the EuroSAT dataset. Adjust the slider to visualize images in the dataset.

[20]:
dataset = EuroSAT(transforms=MinMaxNormalize(mins, maxs))

#@title EuroSat Multispectral (MS) Browser  { run: "auto", vertical-output: true }
idx = 16199 #@param {type:"slider", min:0, max:16199, step:1}
sample = dataset[idx]
rgb = sample["image"][1:4]
rgb = rgb[[2, 1, 0], ...]
image = T.ToPILImage()(rgb)
print(f"Class Label: {dataset.classes[sample['label']]}")
image.resize((256, 256), resample=Image.BILINEAR)
Class Label: PermanentCrop
[20]:
[Output image: RGB rendering of the selected EuroSAT sample, resized to 256x256]
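
The band selection in the cell above can also be written as a single indexing step. The sketch below assumes a normalized image in [0, 1] with the band order listed earlier (B4 red, B3 green, B2 blue at indices 3, 2, 1) and converts it to an 8-bit height x width x RGB array. The helper `to_rgb_uint8` is hypothetical and uses random data in place of a real sample.

```python
import torch


def to_rgb_uint8(image, rgb_indices=(3, 2, 1)):
    """Pick red, green, blue bands from (C, H, W) and return (H, W, 3) uint8."""
    rgb = image[list(rgb_indices)]          # (3, H, W) in R, G, B order
    rgb = rgb.clamp(0, 1).permute(1, 2, 0)  # (H, W, 3)
    return (rgb * 255).round().to(torch.uint8)


toy = torch.rand(13, 64, 64)  # random stand-in for a normalized sample
rgb_image = to_rgb_uint8(toy)
print(rgb_image.shape, rgb_image.dtype)
```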