Shortcuts

Source code for torchgeo.datamodules.chesapeake

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Chesapeake Bay High-Resolution Land Cover Project datamodule."""

from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from pytorch_lightning.core.datamodule import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import ChesapeakeCVPR, stack_samples
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler

# Override ``DataLoader.__module__`` so the class reports ``torch.utils.data`` as
# its module. This mutates the imported class itself, so the change is visible to
# all code in the process, not just this module.
# NOTE(review): presumably a workaround for tooling (e.g. docs generation) that
# resolves the class under a different module path — confirm against the linked
# upstream issue and PR:
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class ChesapeakeCVPRDataModule(LightningDataModule):
    """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

    Uses the random splits defined per state to partition tiles into train, val,
    and test sets.
    """

    def __init__(
        self,
        root_dir: str,
        train_splits: List[str],
        val_splits: List[str],
        test_splits: List[str],
        patches_per_tile: int = 200,
        patch_size: int = 256,
        batch_size: int = 64,
        num_workers: int = 0,
        class_set: int = 7,
        use_prior_labels: bool = False,
        prior_smoothing_constant: float = 1e-4,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the ChesapeakeCVPR Dataset
                classes
            train_splits: The splits used to train the model, e.g. ["ny-train"]
            val_splits: The splits used to validate the model, e.g. ["ny-val"]
            test_splits: The splits used to test the model, e.g. ["ny-test"]
            patches_per_tile: The number of patches per tile to sample
            patch_size: The size of each patch in pixels (test patches will be 2
                times this size)
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            class_set: The high-resolution land cover class set to use - 5 or 7
            use_prior_labels: Flag for using a prior over high-resolution classes
                instead of the high-resolution labels themselves
            prior_smoothing_constant: additive smoothing to add when using prior
                labels
            **kwargs: accepted but not used by this class

        Raises:
            AssertionError: if a split is not in ``ChesapeakeCVPR.splits`` or
                ``class_set`` is not 5 or 7
            ValueError: if ``use_prior_labels`` is used with ``class_set==7``
        """
        super().__init__()  # type: ignore[no-untyped-call]
        for state in train_splits + val_splits + test_splits:
            assert state in ChesapeakeCVPR.splits
        assert class_set in [5, 7]
        if use_prior_labels and class_set != 5:
            raise ValueError(
                "The pre-generated prior labels are only valid for the 5"
                + " class set of labels"
            )

        self.root_dir = root_dir
        self.train_splits = train_splits
        self.val_splits = val_splits
        self.test_splits = test_splits
        self.patches_per_tile = patches_per_tile
        self.patch_size = patch_size
        # This is a rough estimate of how large of a patch we will need to sample in
        # EPSG:3857 in order to guarantee a large enough patch in the local CRS.
        self.original_patch_size = int(patch_size * 2.0)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.class_set = class_set
        self.use_prior_labels = use_prior_labels
        self.prior_smoothing_constant = prior_smoothing_constant

        # The second layer is the target: either the prior over classes or the
        # high-resolution land cover labels.
        if self.use_prior_labels:
            self.layers = [
                "naip-new",
                "prior_from_cooccurrences_101_31_no_osm_no_buildings",
            ]
        else:
            self.layers = ["naip-new", "lc"]

    def pad_to(
        self, size: int = 512, image_value: int = 0, mask_value: int = 0
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to perform a padding transform on a single sample.

        Padding is added only on the bottom and right edges.

        Args:
            size: output image size
            image_value: value to pad image with
            mask_value: value to pad mask with

        Returns:
            function to perform padding
        """

        def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            _, height, width = sample["image"].shape
            assert height <= size and width <= size

            height_pad = size - height
            width_pad = size - width

            # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            # for a description of the format of the padding tuple
            sample["image"] = F.pad(
                sample["image"],
                (0, width_pad, 0, height_pad),
                mode="constant",
                value=image_value,
            )
            sample["mask"] = F.pad(
                sample["mask"],
                (0, width_pad, 0, height_pad),
                mode="constant",
                value=mask_value,
            )
            return sample

        return pad_inner

    def center_crop(
        self, size: int = 512
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to perform a center crop transform on a single sample.

        Args:
            size: output image size

        Returns:
            function to perform center crop
        """

        def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            _, height, width = sample["image"].shape

            # If the input is smaller than ``size`` these offsets are negative and
            # the slice is not ``size`` x ``size``; in ``setup`` this transform is
            # always followed by ``nodata_check``, which replaces such samples.
            y1 = (height - size) // 2
            x1 = (width - size) // 2
            sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size]
            sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size]

            return sample

        return center_crop_inner

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocesses a single sample.

        Scales the image, converts the mask into either a smoothed prior
        distribution (``use_prior_labels``) or integer class labels, and drops
        the ``bbox`` entry.

        Args:
            sample: sample dictionary containing image and mask

        Returns:
            preprocessed sample
        """
        # NOTE(review): dividing by 255 assumes 8-bit imagery — confirm.
        sample["image"] = sample["image"] / 255.0
        sample["mask"] = sample["mask"].squeeze()

        if self.use_prior_labels:
            # L1-normalize along dim 0, add additive smoothing, then renormalize.
            sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0)
            sample["mask"] = F.normalize(
                sample["mask"] + self.prior_smoothing_constant, p=1, dim=0
            )
        else:
            if self.class_set == 5:
                # Collapse classes 5 and 6 into class 4 for the 5-class set.
                sample["mask"][sample["mask"] == 5] = 4
                sample["mask"][sample["mask"] == 6] = 4
            sample["mask"] = sample["mask"].long()

        sample["image"] = sample["image"].float()
        del sample["bbox"]

        return sample

    def nodata_check(
        self, size: int = 512
    ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]:
        """Returns a function to check for nodata or mis-sized input.

        Samples whose image is smaller than ``size`` in either spatial dimension
        have their image and mask replaced with all-zero tensors.

        Args:
            size: output image size

        Returns:
            function to check for nodata values
        """

        def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
            num_channels, height, width = sample["image"].shape

            if height < size or width < size:
                sample["image"] = torch.zeros(  # type: ignore[attr-defined]
                    (num_channels, size, size)
                )
                # Keep any non-singleton leading dims of the mask so that a
                # multi-channel (prior-label) mask stays multi-channel and still
                # stacks with normal samples. Singleton leading dims are dropped,
                # so a single-channel mask becomes (size, size) as before.
                leading_dims = [d for d in sample["mask"].shape[:-2] if d != 1]
                sample["mask"] = torch.zeros(  # type: ignore[attr-defined]
                    (*leading_dims, size, size)
                )

            return sample

        return nodata_check_inner

    def prepare_data(self) -> None:
        """Confirms that the dataset is downloaded on the local node.

        This method is called once per node, while :func:`setup` is called once per
        GPU.
        """
        # Only the train splits are passed here; download/checksum are disabled.
        ChesapeakeCVPR(
            self.root_dir,
            splits=self.train_splits,
            layers=self.layers,
            transforms=None,
            download=False,
            checksum=False,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        Args:
            stage: stage to set up
        """
        # Train/val patches are sampled at ``original_patch_size``, then center
        # cropped to ``patch_size``; test patches are padded up to
        # ``original_patch_size`` instead.
        train_transforms = Compose(
            [
                self.center_crop(self.patch_size),
                self.nodata_check(self.patch_size),
                self.preprocess,
            ]
        )
        val_transforms = Compose(
            [
                self.center_crop(self.patch_size),
                self.nodata_check(self.patch_size),
                self.preprocess,
            ]
        )
        test_transforms = Compose(
            [
                self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
                self.preprocess,
            ]
        )

        self.train_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.train_splits,
            layers=self.layers,
            transforms=train_transforms,
            download=False,
            checksum=False,
        )
        self.val_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.val_splits,
            layers=self.layers,
            transforms=val_transforms,
            download=False,
            checksum=False,
        )
        self.test_dataset = ChesapeakeCVPR(
            self.root_dir,
            splits=self.test_splits,
            layers=self.layers,
            transforms=test_transforms,
            download=False,
            checksum=False,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        # Draw ``patches_per_tile`` random patches per dataset entry, in batches.
        sampler = RandomBatchGeoSampler(
            self.train_dataset,
            size=self.original_patch_size,
            batch_size=self.batch_size,
            length=self.patches_per_tile * len(self.train_dataset),
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        # Non-overlapping grid: stride equals patch size.
        sampler = GridGeoSampler(
            self.val_dataset,
            size=self.original_patch_size,
            stride=self.original_patch_size,
        )
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        # Non-overlapping grid: stride equals patch size.
        sampler = GridGeoSampler(
            self.test_dataset,
            size=self.original_patch_size,
            stride=self.original_patch_size,
        )
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=stack_samples,
        )

© Copyright 2021, Microsoft Corporation. Revision e1285e6c.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources