Shortcuts

Source code for torchgeo.datamodules.sen12ms

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SEN12MS datamodule."""

from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import DataLoader, Subset

from ..datasets import SEN12MS

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class SEN12MSDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the SEN12MS dataset.

    Implements 80/20 geographic train/val splits and uses the test split from the
    classification dataset definitions. See :func:`setup` for more details.

    Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See
    https://arxiv.org/abs/2002.08254.
    """

    #: Lookup table translating IGBP class ids to the DFC2020 scheme, adapted from
    #: the dataloader at https://github.com/lukasliebel/dfc2020_baseline.
    #: Index ``i`` holds the DFC2020 class that IGBP class ``i`` maps to, e.g.
    #: IGBP classes 1-5 all collapse into DFC2020 class 1.
    DFC2020_CLASS_MAPPING = torch.tensor(  # type: ignore[attr-defined]
        [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]
    )
[docs] def __init__( self, root_dir: str, seed: int, band_set: str = "all", batch_size: int = 64, num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for SEN12MS based DataLoaders. Args: root_dir: The ``root`` arugment to pass to the SEN12MS Dataset classes seed: The seed value to use when doing the sklearn based ShuffleSplit band_set: The subset of S1/S2 bands to use. Options are: "all", "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: B2, B3, B4, B8, B11, and B12. batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders """ super().__init__() # type: ignore[no-untyped-call] assert band_set in SEN12MS.BAND_SETS.keys() self.root_dir = root_dir self.seed = seed self.band_set = band_set self.band_indices = SEN12MS.BAND_SETS[band_set] self.batch_size = batch_size self.num_workers = num_workers
[docs] def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. Args: sample: dictionary containing image and mask Returns: preprocessed sample """ sample["image"] = sample["image"].float() if self.band_set == "all": sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 elif self.band_set == "s1": sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 else: sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 sample["mask"] = sample["mask"][0, :, :].long() sample["mask"] = torch.take( # type: ignore[attr-defined] self.DFC2020_CLASS_MAPPING, sample["mask"] ) return sample
[docs] def setup(self, stage: Optional[str] = None) -> None: """Create the train/val/test splits based on the original Dataset objects. The splits should be done here vs. in :func:`__init__` per the docs: https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. We split samples between train and val geographically with proportions of 80/20. This mimics the geographic test set split. Args: stage: stage to set up """ season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} self.all_train_dataset = SEN12MS( self.root_dir, split="train", bands=self.band_indices, transforms=self.custom_transform, checksum=False, ) self.all_test_dataset = SEN12MS( self.root_dir, split="test", bands=self.band_indices, transforms=self.custom_transform, checksum=False, ) # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" # This patch will belong to the scene that is uniquelly identified by its # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply # give each season a large number and representing a `unique_scene_id` as # `season_id + scene_id`. scenes = [] for scene_fn in self.all_train_dataset.ids: parts = scene_fn.split("_") season_id = season_to_int[parts[1]] scene_id = int(parts[3]) scenes.append(season_id + scene_id) train_indices, val_indices = next( GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( scenes, groups=scenes ) ) self.train_dataset = Subset(self.all_train_dataset, train_indices) self.val_dataset = Subset(self.all_train_dataset, val_indices) self.test_dataset = Subset( self.all_test_dataset, range(len(self.all_test_dataset)) )
[docs] def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. Returns: training data loader """ return DataLoader( self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, )
[docs] def val_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for validation. Returns: validation data loader """ return DataLoader( self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, )
[docs] def test_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for testing. Returns: testing data loader """ return DataLoader( self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, )

© Copyright 2021, Microsoft Corporation. Revision e1285e6c.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources