
Source code for torchgeo.datamodules.chesapeake

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Chesapeake Bay High-Resolution Land Cover Project datamodule."""

from typing import Any

import kornia.augmentation as K
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from ..datasets import ChesapeakeCVPR
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class _Transform(nn.Module):
    """Version of AugmentationSequential designed for samples, not batches."""

    def __init__(self, aug: nn.Module) -> None:
        """Initialize a new _Transform instance.

        Args:
            aug: Augmentation to apply.
        """
        super().__init__()
        self.aug = aug

    def forward(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Apply the augmentation.

        Args:
            sample: Input sample.

        Returns:
            Augmented sample.
        """
        for key in ["image", "mask"]:
            dtype = sample[key].dtype
            # All inputs must be float
            sample[key] = sample[key].float()
            sample[key] = self.aug(sample[key])
            sample[key] = sample[key].to(dtype)
            # Kornia adds batch dimension
            sample[key] = rearrange(sample[key], "() c h w -> c h w")
        return sample
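
A minimal usage sketch (not part of the original module). Kornia augmentations
expect batched ``(b, c, h, w)`` tensors, while dataset-level ``transforms``
receive single unbatched samples; ``_Transform`` bridges the two. The tensor
shapes below are illustrative:

import torch
import kornia.augmentation as K

transform = _Transform(K.CenterCrop(256))
sample = {
    "image": torch.rand(4, 512, 512),  # NAIP imagery: 4 bands (R, G, B, NIR)
    "mask": torch.randint(0, 7, (1, 512, 512)),  # integer class labels
}
sample = transform(sample)
assert sample["image"].shape == (4, 256, 256)
assert sample["mask"].dtype == torch.int64  # original dtype is restored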


class ChesapeakeCVPRDataModule(GeoDataModule):
    """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

    Uses the random splits defined per state to partition tiles into train, val,
    and test sets.
    """
    def __init__(
        self,
        train_splits: list[str],
        val_splits: list[str],
        test_splits: list[str],
        batch_size: int = 64,
        patch_size: int = 256,
        length: int | None = None,
        num_workers: int = 0,
        class_set: int = 7,
        use_prior_labels: bool = False,
        prior_smoothing_constant: float = 1e-4,
        **kwargs: Any,
    ) -> None:
        """Initialize a new ChesapeakeCVPRDataModule instance.

        Args:
            train_splits: Splits used to train the model, e.g., ["ny-train"].
            val_splits: Splits used to validate the model, e.g., ["ny-val"].
            test_splits: Splits used to test the model, e.g., ["ny-test"].
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
                Should be a multiple of 32 for most segmentation architectures.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            class_set: The high-resolution land cover class set to use (5 or 7).
            use_prior_labels: Flag for using a prior over high-resolution classes
                instead of the high-resolution labels themselves.
            prior_smoothing_constant: Additive smoothing to add when using
                prior labels.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.ChesapeakeCVPR`.

        Raises:
            ValueError: If ``use_prior_labels=True`` is used with ``class_set=7``.
        """
        # This is a rough estimate of how large of a patch we will need to sample in
        # EPSG:3857 in order to guarantee a large enough patch in the local CRS.
        self.original_patch_size = patch_size * 3
        kwargs["transforms"] = _Transform(K.CenterCrop(patch_size))

        super().__init__(
            ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs
        )

        assert class_set in [5, 7]
        if use_prior_labels and class_set == 7:
            raise ValueError(
                "The pre-generated prior labels are only valid for the 5"
                " class set of labels"
            )

        self.train_splits = train_splits
        self.val_splits = val_splits
        self.test_splits = test_splits
        self.class_set = class_set
        self.use_prior_labels = use_prior_labels
        self.prior_smoothing_constant = prior_smoothing_constant

        if self.use_prior_labels:
            self.layers = [
                "naip-new",
                "prior_from_cooccurrences_101_31_no_osm_no_buildings",
            ]
        else:
            self.layers = ["naip-new", "lc"]

        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
        )
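
A hedged construction sketch. The split names follow the docstring examples,
and the ``root`` path is illustrative; it is forwarded to
:class:`~torchgeo.datasets.ChesapeakeCVPR` via ``**kwargs``:

datamodule = ChesapeakeCVPRDataModule(
    train_splits=["ny-train"],
    val_splits=["ny-val"],
    test_splits=["ny-test"],
    batch_size=32,
    patch_size=256,
    class_set=5,
    root="data/cvpr_chesapeake",  # hypothetical dataset location
)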
    def setup(self, stage: str) -> None:
        """Set up datasets and samplers.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ["fit"]:
            self.train_dataset = ChesapeakeCVPR(
                splits=self.train_splits, layers=self.layers, **self.kwargs
            )
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset,
                self.original_patch_size,
                self.batch_size,
                self.length,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = ChesapeakeCVPR(
                splits=self.val_splits, layers=self.layers, **self.kwargs
            )
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.original_patch_size, self.original_patch_size
            )
        if stage in ["test"]:
            self.test_dataset = ChesapeakeCVPR(
                splits=self.test_splits, layers=self.layers, **self.kwargs
            )
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.original_patch_size, self.original_patch_size
            )
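
In normal use Lightning's ``Trainer`` calls ``setup`` itself; a sketch of
driving the datamodule by hand instead, assuming the dataset is already on
disk and that ``train_dataloader`` is inherited from ``GeoDataModule``:

datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
# Patches are sampled at original_patch_size (3 * patch_size) in EPSG:3857,
# then center-cropped to patch_size by the _Transform above.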
    def on_after_batch_transfer(
        self, batch: dict[str, Tensor], dataloader_idx: int
    ) -> dict[str, Tensor]:
        """Apply batch augmentations to the batch after it is transferred to the device.

        Args:
            batch: A batch of data that needs to be altered or augmented.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A batch of data.
        """
        if self.use_prior_labels:
            # Renormalize the soft prior so each pixel's class vector sums to 1,
            # then apply additive smoothing so no class has zero probability.
            # The prior must stay a float distribution; casting it to an integer
            # dtype here would destroy the soft labels.
            batch["mask"] = F.normalize(batch["mask"].float(), p=1, dim=1)
            batch["mask"] = F.normalize(
                batch["mask"] + self.prior_smoothing_constant, p=1, dim=1
            )
        else:
            if self.class_set == 5:
                # Collapse classes 5 and 6 into class 4 for the 5-class set
                batch["mask"][batch["mask"] == 5] = 4
                batch["mask"][batch["mask"] == 6] = 4

        return super().on_after_batch_transfer(batch, dataloader_idx)
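
A small worked example of the prior-label branch: ``F.normalize(..., p=1,
dim=1)`` divides each pixel's class vector by its L1 norm, and the smoothing
constant keeps every class probability strictly positive. The numbers are
illustrative:

import torch
import torch.nn.functional as F

# One pixel with raw prior mass over 5 classes: (b, c, h, w) = (1, 5, 1, 1)
prior = torch.tensor([2.0, 0.0, 6.0, 0.0, 0.0]).reshape(1, 5, 1, 1)
p = F.normalize(prior, p=1, dim=1)     # [0.25, 0.0, 0.75, 0.0, 0.0]
p = F.normalize(p + 1e-4, p=1, dim=1)  # every class now has mass > 0
assert torch.allclose(p.sum(dim=1), torch.ones(1, 1, 1))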
