Source code for torchgeo.datamodules.sen12ms
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""SEN12MS datamodule."""
from typing import Any, Dict
import torch
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import Subset
from ..datasets import SEN12MS
from .geo import NonGeoDataModule
[docs]class SEN12MSDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the SEN12MS dataset.
Implements 80/20 geographic train/val splits and uses the test split from the
classification dataset definitions.
Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See
https://arxiv.org/abs/2002.08254.
"""
#: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader
#: here: https://github.com/lukasliebel/dfc2020_baseline.
DFC2020_CLASS_MAPPING = torch.tensor(
[0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]
)
std = torch.tensor(
[-25, -25, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4]
)
[docs] def __init__(
self,
batch_size: int = 64,
num_workers: int = 0,
band_set: str = "all",
**kwargs: Any,
) -> None:
"""Initialize a new SEN12MSDataModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
band_set: Subset of S1/S2 bands to use. Options are: "all",
"s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
B2, B3, B4, B8, B11, and B12.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SEN12MS`.
"""
kwargs["bands"] = SEN12MS.BAND_SETS[band_set]
if band_set == "s1":
self.std = self.std[:2]
elif band_set == "s2-all":
self.std = self.std[2:]
elif band_set == "s2-reduced":
self.std = self.std[torch.tensor([3, 4, 5, 9, 12, 13])]
super().__init__(SEN12MS, batch_size, num_workers, **kwargs)
[docs] def setup(self, stage: str) -> None:
"""Set up datasets.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ["fit", "validate"]:
season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}
self.dataset = SEN12MS(split="train", **self.kwargs)
# A patch is a filename like:
# "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif"
# This patch will belong to the scene that is uniquely identified by its
# (season, scene_id) tuple. Because the largest scene_id is 149, we can
# simply give each season a large number and representing a unique_scene_id
# as (season_id + scene_id).
scenes = []
for scene_fn in self.dataset.ids:
parts = scene_fn.split("_")
season_id = season_to_int[parts[1]]
scene_id = int(parts[3])
scenes.append(season_id + scene_id)
train_indices, val_indices = next(
GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
scenes, groups=scenes
)
)
self.train_dataset = Subset(self.dataset, train_indices)
self.val_dataset = Subset(self.dataset, val_indices)
if stage in ["test"]:
self.test_dataset = SEN12MS(split="test", **self.kwargs)
[docs] def on_after_batch_transfer(
self, batch: Dict[str, Tensor], dataloader_idx: int
) -> Dict[str, Tensor]:
"""Apply batch augmentations to the batch after it is transferred to the device.
Args:
batch: A batch of data that needs to be altered or augmented.
dataloader_idx: The index of the dataloader to which the batch belongs.
Returns:
A batch of data.
"""
batch["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, batch["mask"])
return super().on_after_batch_transfer(batch, dataloader_idx)