
Source code for torchgeo.datamodules.geo

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Base classes for all :mod:`torchgeo` data modules."""

from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import kornia.augmentation as K
import matplotlib.pyplot as plt
import torch
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, default_collate

from ..datasets import GeoDataset, NonGeoDataset, stack_samples
from ..samplers import (
    BatchGeoSampler,
    GeoSampler,
    GridGeoSampler,
    RandomBatchGeoSampler,
)
from ..transforms import AugmentationSequential
from .utils import MisconfigurationException


class GeoDataModule(LightningDataModule):
    """Base class for data modules containing geospatial information.

    .. versionadded:: 0.4
    """

    mean = torch.tensor(0)
    std = torch.tensor(255)
    def __init__(
        self,
        dataset_class: Type[GeoDataset],
        batch_size: int = 1,
        patch_size: Union[int, Tuple[int, int]] = 64,
        length: int = 1000,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new GeoDataModule instance.

        Args:
            dataset_class: Class used to instantiate a new dataset.
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to ``dataset_class``.
        """
        super().__init__()

        self.dataset_class = dataset_class
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.length = length
        self.num_workers = num_workers
        self.kwargs = kwargs

        # Datasets
        self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None

        # Samplers
        self.sampler: Optional[GeoSampler] = None
        self.train_sampler: Optional[GeoSampler] = None
        self.val_sampler: Optional[GeoSampler] = None
        self.test_sampler: Optional[GeoSampler] = None
        self.predict_sampler: Optional[GeoSampler] = None

        # Batch samplers
        self.batch_sampler: Optional[BatchGeoSampler] = None
        self.train_batch_sampler: Optional[BatchGeoSampler] = None
        self.val_batch_sampler: Optional[BatchGeoSampler] = None
        self.test_batch_sampler: Optional[BatchGeoSampler] = None
        self.predict_batch_sampler: Optional[BatchGeoSampler] = None

        # Data loaders
        self.train_batch_size: Optional[int] = None
        self.val_batch_size: Optional[int] = None
        self.test_batch_size: Optional[int] = None
        self.predict_batch_size: Optional[int] = None

        # Collation
        self.collate_fn = stack_samples

        # Data augmentation
        Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]]
        self.aug: Transform = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
        )
        self.train_aug: Optional[Transform] = None
        self.val_aug: Optional[Transform] = None
        self.test_aug: Optional[Transform] = None
        self.predict_aug: Optional[Transform] = None
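For a concrete data module, a subclass typically overrides ``mean``/``std`` and passes its dataset class plus sensible defaults to this constructor. A minimal sketch, where ``MyRasterDataset`` is a hypothetical :class:`~torchgeo.datasets.GeoDataset` and the statistics are placeholders:

class MyRasterDataModule(GeoDataModule):
    # Placeholder normalization statistics; compute real values from the
    # training split of the actual dataset.
    mean = torch.tensor(100)
    std = torch.tensor(50)

    def __init__(self, **kwargs: Any) -> None:
        # Remaining kwargs (e.g. ``root`` or ``download``) are forwarded to
        # MyRasterDataset when ``setup`` instantiates it.
        super().__init__(
            MyRasterDataset, batch_size=64, patch_size=256, length=1000, **kwargs
        )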
    def prepare_data(self) -> None:
        """Download and prepare data.

        During distributed training, this method is called only within a single
        process to avoid corrupted data. This method should not set state since
        it is not called on every device; use :meth:`setup` instead.
        """
        if self.kwargs.get("download", False):
            self.dataset_class(**self.kwargs)
    def setup(self, stage: str) -> None:
        """Set up datasets and samplers.

        Called at the beginning of fit, validate, test, or predict. During
        distributed training, this method is called from every process across all
        the nodes. Setting state here is recommended.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="train", **self.kwargs
            )
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="val", **self.kwargs
            )
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="test", **self.kwargs
            )
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )
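The default ``setup`` assumes the dataset accepts a ``split`` keyword. Datasets without predefined splits can be wired up differently: assign the single dataset to ``self.dataset`` and set the generic or stage-specific samplers, which the dataloaders below fall back to. A sketch under that assumption (``MyRasterDataset`` remains hypothetical):

class MySingleSplitDataModule(GeoDataModule):
    def setup(self, stage: str) -> None:
        # One dataset, sampled differently per stage: random patches for
        # training, a regular grid for validation.
        self.dataset = MyRasterDataset(**self.kwargs)
        if stage == "fit":
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.dataset, self.patch_size, self.patch_size
            )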
    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for training.

        Returns:
            A collection of data loaders specifying training samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'train_dataset'.
        """
        dataset = self.train_dataset or self.dataset
        sampler = self.train_sampler or self.sampler
        batch_sampler = self.train_batch_sampler or self.batch_sampler
        if dataset is not None and (sampler or batch_sampler) is not None:
            batch_size = self.train_batch_size or self.batch_size
            if batch_sampler is not None:
                batch_size = 1
                sampler = None
            return DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                batch_sampler=batch_sampler,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'"
            raise MisconfigurationException(msg)
    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for validation.

        Returns:
            A collection of data loaders specifying validation samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'val_dataset'.
        """
        dataset = self.val_dataset or self.dataset
        sampler = self.val_sampler or self.sampler
        batch_sampler = self.val_batch_sampler or self.batch_sampler
        if dataset is not None and (sampler or batch_sampler) is not None:
            batch_size = self.val_batch_size or self.batch_size
            if batch_sampler is not None:
                batch_size = 1
                sampler = None
            return DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                batch_sampler=batch_sampler,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'"
            raise MisconfigurationException(msg)
    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for testing.

        Returns:
            A collection of data loaders specifying testing samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'test_dataset'.
        """
        dataset = self.test_dataset or self.dataset
        sampler = self.test_sampler or self.sampler
        batch_sampler = self.test_batch_sampler or self.batch_sampler
        if dataset is not None and (sampler or batch_sampler) is not None:
            batch_size = self.test_batch_size or self.batch_size
            if batch_sampler is not None:
                batch_size = 1
                sampler = None
            return DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                batch_sampler=batch_sampler,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'"
            raise MisconfigurationException(msg)
    def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for prediction.

        Returns:
            A collection of data loaders specifying prediction samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'predict_dataset'.
        """
        dataset = self.predict_dataset or self.dataset
        sampler = self.predict_sampler or self.sampler
        batch_sampler = self.predict_batch_sampler or self.batch_sampler
        if dataset is not None and (sampler or batch_sampler) is not None:
            batch_size = self.predict_batch_size or self.batch_size
            if batch_sampler is not None:
                batch_size = 1
                sampler = None
            return DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                batch_sampler=batch_sampler,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'"
            raise MisconfigurationException(msg)
    def transfer_batch_to_device(
        self, batch: Dict[str, Tensor], device: torch.device, dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Transfer batch to device.

        Defines how custom data types are moved to the target device.

        Args:
            batch: A batch of data that needs to be transferred to a new device.
            device: The target device as defined in PyTorch.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A reference to the data on the new device.
        """
        # Non-Tensor values cannot be moved to a device
        del batch["crs"]
        del batch["bbox"]

        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
        return batch
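Batches collated by :func:`~torchgeo.datasets.stack_samples` carry non-Tensor ``crs`` and ``bbox`` entries that Lightning cannot move to an accelerator, so this hook removes them before delegating to the parent implementation. A subclass whose dataset adds further non-Tensor metadata can follow the same pattern; ``"filename"`` below is a hypothetical key, not something torchgeo produces:

class MyDataModule(GeoDataModule):
    def transfer_batch_to_device(
        self, batch: Dict[str, Tensor], device: torch.device, dataloader_idx: int
    ) -> Dict[str, Tensor]:
        # Hypothetical extra non-Tensor key added by a custom dataset
        batch.pop("filename", None)
        return super().transfer_batch_to_device(batch, device, dataloader_idx)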
    def on_after_batch_transfer(
        self, batch: Dict[str, Tensor], dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Apply batch augmentations to the batch after it is transferred to the device.

        Args:
            batch: A batch of data that needs to be altered or augmented.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A batch of data.
        """
        if self.trainer:
            if self.trainer.training:
                aug = self.train_aug or self.aug
            elif self.trainer.validating or self.trainer.sanity_checking:
                aug = self.val_aug or self.aug
            elif self.trainer.testing:
                aug = self.test_aug or self.aug
            elif self.trainer.predicting:
                aug = self.predict_aug or self.aug
            else:
                # Defensive fallback: avoid an unbound ``aug`` if the trainer
                # is in an unexpected state
                aug = self.aug

            batch = aug(batch)

        return batch
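Because each stage-specific slot defaults to ``None``, a subclass can enable augmentations for training only and let every other stage fall back to plain normalization via ``self.aug``. A sketch, assuming the hypothetical ``MyRasterDataset`` and a segmentation-style sample with an image and a mask:

class MySegmentationDataModule(GeoDataModule):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(MyRasterDataset, **kwargs)
        # Applied on-device to whole mini-batches during training only;
        # val/test/predict fall back to ``self.aug``.
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=["image", "mask"],
        )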
    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run the plot method of the validation dataset if one exists.

        Should only be called during 'fit' or 'validate' stages as ``val_dataset``
        may not exist during other stages.

        Args:
            *args: Arguments passed to plot method.
            **kwargs: Keyword arguments passed to plot method.

        Returns:
            A matplotlib Figure with the image, ground truth, and predictions.
        """
        dataset = self.val_dataset or self.dataset
        if dataset is not None:
            if hasattr(dataset, "plot"):
                return dataset.plot(*args, **kwargs)
class NonGeoDataModule(LightningDataModule):
    """Base class for data modules lacking geospatial information.

    .. versionadded:: 0.4
    """

    mean = torch.tensor(0)
    std = torch.tensor(255)
    def __init__(
        self,
        dataset_class: Type[NonGeoDataset],
        batch_size: int = 1,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new NonGeoDataModule instance.

        Args:
            dataset_class: Class used to instantiate a new dataset.
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to ``dataset_class``.
        """
        super().__init__()

        self.dataset_class = dataset_class
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs

        # Datasets
        self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
        self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None

        # Data loaders
        self.train_batch_size: Optional[int] = None
        self.val_batch_size: Optional[int] = None
        self.test_batch_size: Optional[int] = None
        self.predict_batch_size: Optional[int] = None

        # Collation
        self.collate_fn = default_collate

        # Data augmentation
        Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]]
        self.aug: Transform = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image"]
        )
        self.train_aug: Optional[Transform] = None
        self.val_aug: Optional[Transform] = None
        self.test_aug: Optional[Transform] = None
        self.predict_aug: Optional[Transform] = None
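A minimal sketch of a concrete subclass, wrapping :class:`~torchgeo.datasets.EuroSAT`; the normalization statistics are placeholders, not the dataset's real per-band values:

from torchgeo.datasets import EuroSAT

class MyEuroSATDataModule(NonGeoDataModule):
    # Placeholder statistics; compute real per-band values from the
    # training split.
    mean = torch.tensor(0)
    std = torch.tensor(10000)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(EuroSAT, batch_size=64, num_workers=4, **kwargs)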
    def prepare_data(self) -> None:
        """Download and prepare data.

        During distributed training, this method is called only within a single
        process to avoid corrupted data. This method should not set state since
        it is not called on every device; use :meth:`setup` instead.
        """
        if self.kwargs.get("download", False):
            self.dataset_class(**self.kwargs)
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Called at the beginning of fit, validate, test, or predict. During
        distributed training, this method is called from every process across all
        the nodes. Setting state here is recommended.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="train", **self.kwargs
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="val", **self.kwargs
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(  # type: ignore[call-arg]
                split="test", **self.kwargs
            )
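Not every dataset exposes all three splits. A subclass can override ``setup`` to remap the stages; a sketch, assuming a hypothetical dataset with only 'train' and 'test' splits that reuses the test split for validation:

class MyTrainTestDataModule(NonGeoDataModule):
    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.train_dataset = self.dataset_class(split="train", **self.kwargs)
        if stage in ["fit", "validate"]:
            # No 'val' split exists; fall back to the test split
            self.val_dataset = self.dataset_class(split="test", **self.kwargs)
        if stage == "test":
            self.test_dataset = self.dataset_class(split="test", **self.kwargs)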
    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for training.

        Returns:
            A collection of data loaders specifying training samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'train_dataset'.
        """
        dataset = self.train_dataset or self.dataset
        if dataset is not None:
            return DataLoader(
                dataset=dataset,
                batch_size=self.train_batch_size or self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'"
            raise MisconfigurationException(msg)
    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for validation.

        Returns:
            A collection of data loaders specifying validation samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'val_dataset'.
        """
        dataset = self.val_dataset or self.dataset
        if dataset is not None:
            return DataLoader(
                dataset=dataset,
                batch_size=self.val_batch_size or self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'"
            raise MisconfigurationException(msg)
    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for testing.

        Returns:
            A collection of data loaders specifying testing samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'test_dataset'.
        """
        dataset = self.test_dataset or self.dataset
        if dataset is not None:
            return DataLoader(
                dataset=dataset,
                batch_size=self.test_batch_size or self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'"
            raise MisconfigurationException(msg)
    def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders for prediction.

        Returns:
            A collection of data loaders specifying prediction samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                'predict_dataset'.
        """
        dataset = self.predict_dataset or self.dataset
        if dataset is not None:
            return DataLoader(
                dataset=dataset,
                batch_size=self.predict_batch_size or self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
            )
        else:
            msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'"
            raise MisconfigurationException(msg)
    def on_after_batch_transfer(
        self, batch: Dict[str, Tensor], dataloader_idx: int
    ) -> Dict[str, Tensor]:
        """Apply batch augmentations to the batch after it is transferred to the device.

        Args:
            batch: A batch of data that needs to be altered or augmented.
            dataloader_idx: The index of the dataloader to which the batch belongs.

        Returns:
            A batch of data.
        """
        if self.trainer:
            if self.trainer.training:
                aug = self.train_aug or self.aug
            elif self.trainer.validating or self.trainer.sanity_checking:
                aug = self.val_aug or self.aug
            elif self.trainer.testing:
                aug = self.test_aug or self.aug
            elif self.trainer.predicting:
                aug = self.predict_aug or self.aug
            else:
                # Defensive fallback: avoid an unbound ``aug`` if the trainer
                # is in an unexpected state
                aug = self.aug

            batch = aug(batch)

        return batch
    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run the plot method of the validation dataset if one exists.

        Should only be called during 'fit' or 'validate' stages as ``val_dataset``
        may not exist during other stages.

        Args:
            *args: Arguments passed to plot method.
            **kwargs: Keyword arguments passed to plot method.

        Returns:
            A matplotlib Figure with the image, ground truth, and predictions.
        """
        # Prefer the validation dataset, matching the docstring and the
        # behavior of :meth:`GeoDataModule.plot`
        dataset = self.val_dataset or self.dataset
        if dataset is not None:
            if hasattr(dataset, "plot"):
                return dataset.plot(*args, **kwargs)
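Putting it together, a usage sketch: Lightning invokes ``prepare_data``, ``setup``, the dataloader hooks, and the batch-transfer hooks automatically. ``MyEuroSATDataModule`` is the hypothetical subclass sketched above, and ``task`` stands in for any ``LightningModule``:

from pytorch_lightning import Trainer

datamodule = MyEuroSATDataModule(root="data", download=True)
trainer = Trainer(max_epochs=10)
trainer.fit(model=task, datamodule=datamodule)   # 'fit' stage: train + val loaders
trainer.test(model=task, datamodule=datamodule)  # 'test' stage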
