Source code for torchgeo.datamodules.oscd

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""OSCD datamodule."""

from typing import Any, Dict, List, Optional, Tuple

import kornia.augmentation as K
import pytorch_lightning as pl
import torch
from einops import repeat
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision.transforms import Compose, Normalize

from ..datasets import OSCD
from .utils import dataset_split


class OSCDDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the OSCD dataset.

    Uses the train/test splits from the dataset and further splits
    the train split into train/val splits.

    .. versionadded:: 0.2
    """

    # Per-band mean/std normalization statistics for the 13 Sentinel-2 bands in OSCD
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            1583.0741,
            1374.3202,
            1294.1616,
            1325.6158,
            1478.7408,
            1933.0822,
            2166.0608,
            2076.4868,
            2306.0652,
            690.9814,
            16.2360,
            2080.3347,
            1524.6930,
        ]
    )

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            52.1937,
            83.4168,
            105.6966,
            151.1401,
            147.4615,
            115.9289,
            123.1974,
            114.6483,
            141.4530,
            73.2758,
            4.8368,
            213.4821,
            179.4793,
        ]
    )

    def __init__(
        self,
        root_dir: str,
        bands: str = "all",
        train_batch_size: int = 32,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        patch_size: Tuple[int, int] = (64, 64),
        num_patches_per_tile: int = 32,
        pad_size: Tuple[int, int] = (1280, 1280),
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for OSCD-based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the OSCD Dataset classes
            bands: "rgb" or "all"
            train_batch_size: The batch size used in the train DataLoader
                (val_batch_size == test_batch_size == 1)
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            patch_size: Size of random patch from image and mask (height, width)
            num_patches_per_tile: number of random patches per sample
            pad_size: size to pad images to during val/test steps
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.bands = bands
        self.train_batch_size = train_batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.patch_size = patch_size
        self.num_patches_per_tile = num_patches_per_tile

        if bands == "rgb":
            # select the red, green, blue bands (B04, B03, B02) and add trailing
            # height/width dims so the statistics broadcast over images
            self.band_means = self.band_means[[3, 2, 1], None, None]
            self.band_stds = self.band_stds[[3, 2, 1], None, None]
        else:
            self.band_means = self.band_means[:, None, None]
            self.band_stds = self.band_stds[:, None, None]

        self.norm = Normalize(self.band_means, self.band_stds)
        # crop image and mask together, at the same location within each sample
        self.rcrop = K.AugmentationSequential(
            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
        )
        self.padto = K.PadTo(pad_size)
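    # Constructor sketch (illustrative, not part of the original module; the data
    # root below is a hypothetical path). With bands="rgb" only Sentinel-2 bands
    # B04/B03/B02 are normalized and loaded:
    #
    #     dm = OSCDDataModule(root_dir="data/oscd", bands="rgb", val_split_pct=0.1)
    #     dm.prepare_data()
    #     dm.setup()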
    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset."""
        sample["image"] = sample["image"].float()
        sample["image"] = self.norm(sample["image"])
        # merge the bitemporal and spectral dims: (2, C, H, W) -> (2 * C, H, W)
        sample["image"] = torch.flatten(  # type: ignore[attr-defined]
            sample["image"], 0, 1
        )
        return sample
    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        OSCD(self.root_dir, split="train", bands=self.bands, checksum=False)
    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.
        """

        def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]:
            # take ``num_patches_per_tile`` random crops from each tile; the mask
            # is repeated along a temporal dim so kornia crops image and mask at
            # the same location
            images, masks = [], []
            for i in range(self.num_patches_per_tile):
                mask = repeat(sample["mask"], "h w -> t h w", t=2).float()
                image, mask = self.rcrop(sample["image"], mask)
                mask = mask.squeeze()[0]
                images.append(image.squeeze())
                masks.append(mask.long())
            sample["image"] = torch.stack(images)
            sample["mask"] = torch.stack(masks)
            return sample

        def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]:
            sample["image"] = self.padto(sample["image"])[0]
            sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0]
            return sample

        train_transforms = Compose([self.preprocess, n_random_crop])
        # for testing and validation we pad all inputs to a fixed size to avoid
        # issues with the upsampling paths in encoder-decoder architectures
        test_transforms = Compose([self.preprocess, pad_to])

        train_dataset = OSCD(
            self.root_dir, split="train", bands=self.bands, transforms=train_transforms
        )

        self.train_dataset: Dataset[Any]
        self.val_dataset: Dataset[Any]

        if self.val_split_pct > 0.0:
            val_dataset = OSCD(
                self.root_dir,
                split="train",
                bands=self.bands,
                transforms=test_transforms,
            )
            self.train_dataset, self.val_dataset, _ = dataset_split(
                train_dataset, val_pct=self.val_split_pct, test_pct=0.0
            )
            # point the val subset at a copy of the dataset that uses the
            # non-augmenting test transforms
            self.val_dataset.dataset = val_dataset
        else:
            self.train_dataset = train_dataset
            self.val_dataset = train_dataset

        self.test_dataset = OSCD(
            self.root_dir, split="test", bands=self.bands, transforms=test_transforms
        )
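    # Shape sketch (my reading of the transforms above, assuming bands="all" and
    # the default patch/pad sizes): after ``preprocess`` + ``n_random_crop`` a
    # training sample holds
    #
    #     sample["image"]: (num_patches_per_tile, 26, 64, 64)
    #     sample["mask"]:  (num_patches_per_tile, 64, 64)
    #
    # while ``preprocess`` + ``pad_to`` leaves a single (26, 1280, 1280) image and
    # a (1280, 1280) mask for the val/test splits.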
    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training."""

        def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            r_batch: Dict[str, Any] = default_collate(  # type: ignore[no-untyped-call]
                batch
            )
            r_batch["image"] = torch.flatten(  # type: ignore[attr-defined]
                r_batch["image"], 0, 1
            )
            r_batch["mask"] = torch.flatten(  # type: ignore[attr-defined]
                r_batch["mask"], 0, 1
            )
            return r_batch

        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=collate_wrapper,
            shuffle=True,
        )
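    # Note: because ``collate_wrapper`` flattens the tile and patch dims, the
    # effective training batch size is train_batch_size * num_patches_per_tile
    # (32 * 32 = 1024 patches with the defaults).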
    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation."""
        return DataLoader(
            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )
    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing."""
        return DataLoader(
            self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
        )
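
# End-to-end usage sketch (illustrative only, not part of the module; the model
# and path are hypothetical placeholders):
#
#     import pytorch_lightning as pl
#     from torchgeo.datamodules import OSCDDataModule
#
#     datamodule = OSCDDataModule(root_dir="data/oscd", bands="rgb", num_workers=4)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=datamodule)  # ``model``: any LightningModule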
