Source code for torchgeo.datamodules.sen12ms

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SEN12MS datamodule."""

from typing import Any, Dict

import torch
from sklearn.model_selection import GroupShuffleSplit
from torch import Tensor
from torch.utils.data import Subset

from ..datasets import SEN12MS
from .geo import NonGeoDataModule

class SEN12MSDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the SEN12MS dataset.

    Implements 80/20 geographic train/val splits and uses the test split from the
    classification dataset definitions.

    Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition.
    """

    #: Mapping from the IGBP class definitions to the DFC2020 classes, taken from an
    #: external reference dataloader.
    #: NOTE(review): the link to that dataloader is missing from this copy of the
    #: source -- restore it from the upstream repository.
    DFC2020_CLASS_MAPPING = torch.tensor(
        [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]
    )

    #: Per-band scale values for the "all" band set: the first two entries belong
    #: to the S1 bands and the remaining thirteen to the S2 bands (see the slicing
    #: in ``__init__``). Presumably consumed by the parent class for
    #: normalization -- TODO confirm against NonGeoDataModule.
    std = torch.tensor(
        [-25, -25, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4]
    )

    def __init__(
        self,
        batch_size: int = 64,
        num_workers: int = 0,
        band_set: str = "all",
        **kwargs: Any,
    ) -> None:
        """Initialize a new SEN12MSDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            band_set: Subset of S1/S2 bands to use. Options are: "all",
                "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes:
                B2, B3, B4, B8, B11, and B12.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.SEN12MS`.

        Raises:
            KeyError: If ``band_set`` is not a key of ``SEN12MS.BAND_SETS``.
        """
        kwargs["bands"] = SEN12MS.BAND_SETS[band_set]

        # Shadow the class-level ``std`` with an instance attribute that only
        # keeps the entries for the selected bands ("all" keeps every entry).
        if band_set == "s1":
            self.std = self.std[:2]
        elif band_set == "s2-all":
            self.std = self.std[2:]
        elif band_set == "s2-reduced":
            self.std = self.std[torch.tensor([3, 4, 5, 9, 12, 13])]

        super().__init__(SEN12MS, batch_size, num_workers, **kwargs)

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ["fit", "validate"]:
            season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}

            self.dataset = SEN12MS(split="train", **self.kwargs)

            # A patch is a filename like:
            #     "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif"
            # This patch will belong to the scene that is uniquely identified by its
            # (season, scene_id) tuple. Because the largest scene_id is 149, we can
            # simply give each season a large number and represent a unique_scene_id
            # as (season_id + scene_id).
            scenes = [
                season_to_int[parts[1]] + int(parts[3])
                for parts in (scene_fn.split("_") for scene_fn in self.dataset.ids)
            ]

            # Group by scene so that all patches of a scene land on the same side
            # of the split; only the first generated split is used.
            train_indices, val_indices = next(
                GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split(
                    scenes, groups=scenes
                )
            )

            self.train_dataset = Subset(self.dataset, train_indices)
            self.val_dataset = Subset(self.dataset, val_indices)
        if stage in ["test"]:
            self.test_dataset = SEN12MS(split="test", **self.kwargs)

    def on_after_batch_transfer(
        self, batch: Dict[str, Tensor], dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Apply batch augmentations to the batch after it is transferred to the device.

        Remaps ``batch["mask"]`` through :attr:`DFC2020_CLASS_MAPPING` and then
        delegates to the parent implementation.

        Args:
            batch: A batch of data that needs to be altered or augmented.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A batch of data.
        """
        mask = batch["mask"]
        # The mapping is a class-level tensor that is not moved together with the
        # batch, so put it on the mask's device first; otherwise ``torch.take``
        # fails with a device mismatch when training on an accelerator.
        mapping = self.DFC2020_CLASS_MAPPING.to(mask.device)
        batch["mask"] = torch.take(mapping, mask)
        return super().on_after_batch_transfer(batch, dataloader_idx)

© Copyright 2021, Microsoft Corporation. Revision d2d0e231.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources