
torchgeo.datamodules

Geospatial DataModules

Chesapeake Land Cover

class torchgeo.datamodules.ChesapeakeCVPRDataModule(train_splits, val_splits, test_splits, batch_size=64, patch_size=256, length=1000, num_workers=0, class_set=7, use_prior_labels=False, prior_smoothing_constant=0.0001, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

Uses the random splits defined per state to partition tiles into train, val, and test sets.

__init__(train_splits, val_splits, test_splits, batch_size=64, patch_size=256, length=1000, num_workers=0, class_set=7, use_prior_labels=False, prior_smoothing_constant=0.0001, **kwargs)

Initialize a new ChesapeakeCVPRDataModule instance.

Parameters:
  • train_splits (List[str]) – Splits used to train the model, e.g., [“ny-train”].

  • val_splits (List[str]) – Splits used to validate the model, e.g., [“ny-val”].

  • test_splits (List[str]) – Splits used to test the model, e.g., [“ny-test”].

  • batch_size (int) – Size of each mini-batch.

  • patch_size (int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • length (int) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • class_set (int) – The high-resolution land cover class set to use (5 or 7).

  • use_prior_labels (bool) – Flag for using a prior over high-resolution classes instead of the high-resolution labels themselves.

  • prior_smoothing_constant (float) – Additive smoothing constant applied when using prior labels.

  • **kwargs (Any) – Additional keyword arguments passed to ChesapeakeCVPR.

Raises:

ValueError – If use_prior_labels=True is used with class_set=7.
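The documented constraint between class_set and use_prior_labels, together with one common form of additive smoothing, can be sketched as follows. Both helpers (validate_chesapeake_args and smooth_prior) are hypothetical illustrations, not torchgeo functions. The smoothing formula (add the constant to every class, then renormalize) is an assumption about what prior_smoothing_constant does, not a description of the library's code.

```python
from typing import List


def validate_chesapeake_args(class_set: int, use_prior_labels: bool) -> None:
    """Hypothetical check mirroring the documented constructor constraints."""
    # The parameter description lists 5 and 7 as the only class sets.
    if class_set not in (5, 7):
        raise ValueError(f"class_set must be 5 or 7, got {class_set}")
    # Documented: prior labels cannot be combined with the 7-class set.
    if use_prior_labels and class_set == 7:
        raise ValueError("use_prior_labels=True is not supported with class_set=7")


def smooth_prior(probs: List[float], constant: float = 1e-4) -> List[float]:
    """Assumed additive smoothing: add a constant to every class, then renormalize."""
    smoothed = [p + constant for p in probs]
    total = sum(smoothed)
    return [p / total for p in smoothed]


validate_chesapeake_args(class_set=5, use_prior_labels=True)  # accepted
print(smooth_prior([0.0, 1.0], constant=0.5))  # [0.25, 0.75]
```

Smoothing of this kind keeps every class probability strictly positive, which matters if the prior is later passed through a logarithm.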

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

Dict[str, Tensor]

NAIP

class torchgeo.datamodules.NAIPChesapeakeDataModule(batch_size=64, patch_size=256, length=1000, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the NAIP and Chesapeake datasets.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, patch_size=256, length=1000, num_workers=0, **kwargs)

Initialize a new NAIPChesapeakeDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | Tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to NAIP (prefix keys with naip_) and Chesapeake13 (prefix keys with chesapeake_).
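The prefix convention for **kwargs can be illustrated with a small helper that routes each prefixed key to one dataset. split_prefixed_kwargs is hypothetical and only shows the routing idea. The example keys (root, download) are assumptions about which arguments NAIP and Chesapeake13 accept.

```python
from typing import Any, Dict


def split_prefixed_kwargs(kwargs: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep the entries whose key starts with `prefix` and strip that prefix."""
    return {
        key[len(prefix):]: value
        for key, value in kwargs.items()
        if key.startswith(prefix)
    }


kwargs = {
    "naip_root": "data/naip",
    "chesapeake_root": "data/chesapeake",
    "chesapeake_download": True,
}
naip_kwargs = split_prefixed_kwargs(kwargs, "naip_")
chesapeake_kwargs = split_prefixed_kwargs(kwargs, "chesapeake_")
print(naip_kwargs)        # {'root': 'data/naip'}
print(chesapeake_kwargs)  # {'root': 'data/chesapeake', 'download': True}
```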

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

plot(*args, **kwargs)

Run the NAIP plot method.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

New in version 0.4.

Non-geospatial DataModules

BigEarthNet

class torchgeo.datamodules.BigEarthNetDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the BigEarthNet dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new BigEarthNetDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to BigEarthNet.

COWC

class torchgeo.datamodules.COWCCountingDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the COWC Counting dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new COWCCountingDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to COWCCounting.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Deep Globe Land Cover Challenge

class torchgeo.datamodules.DeepGlobeLandCoverDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the DeepGlobe Land Cover dataset.

Uses the train/test splits from the dataset.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new DeepGlobeLandCoverDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (Tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to DeepGlobeLandCover.
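The "either size or (height, width)" convention and the multiple-of-32 guideline can be checked with the hypothetical helpers below. The rationale given for 32 (five stride-2 downsampling stages in a typical U-Net-style encoder) is a common one, but this page does not state it.

```python
from typing import Tuple, Union

PatchSize = Union[int, Tuple[int, int]]


def normalize_patch_size(patch_size: PatchSize) -> Tuple[int, int]:
    """Expand a single int to (height, width); pass a pair through unchanged."""
    if isinstance(patch_size, int):
        return (patch_size, patch_size)
    height, width = patch_size
    return (height, width)


def fits_downsampling(patch_size: PatchSize, factor: int = 32) -> bool:
    """True if both sides divide evenly by `factor`.

    32 = 2**5, the total downsampling of an encoder with five stride-2 stages.
    """
    height, width = normalize_patch_size(patch_size)
    return height % factor == 0 and width % factor == 0


print(normalize_patch_size(64))       # (64, 64)
print(fits_downsampling((256, 512)))  # True
print(fits_downsampling(100))         # False
```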

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

ETCI2021 Flood Detection

class torchgeo.datamodules.ETCI2021DataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the ETCI2021 dataset.

Splits the existing train split from the dataset into train/val with 80/20 proportions, then uses the existing val dataset as the test data.
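The split scheme described above can be sketched on plain lists. etci_style_splits and the item names are hypothetical. Shuffling and slicing is one simple way to get an 80/20 split, and torchgeo may perform the split differently.

```python
import random
from typing import Dict, List, Sequence


def etci_style_splits(
    train_items: Sequence[str], val_items: Sequence[str], seed: int = 0
) -> Dict[str, List[str]]:
    """Re-split the original train set 80/20 and reuse the original val set as test."""
    shuffled = list(train_items)
    random.Random(seed).shuffle(shuffled)  # seeded for reproducible splits
    n_val = int(len(shuffled) * 0.2)
    return {
        "train": shuffled[n_val:],
        "val": shuffled[:n_val],
        "test": list(val_items),  # original val split becomes the test set
    }


splits = etci_style_splits([f"tile_{i}" for i in range(10)], ["v0", "v1", "v2"])
print({name: len(items) for name, items in splits.items()})  # {'train': 8, 'val': 2, 'test': 3}
```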

New in version 0.2.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new ETCI2021DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to ETCI2021.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

Dict[str, Tensor]

EuroSAT

class torchgeo.datamodules.EuroSATDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the EuroSAT dataset.

Uses the train/val/test splits from the dataset.

New in version 0.2.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new EuroSATDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to EuroSAT.

FAIR1M

class torchgeo.datamodules.FAIR1MDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the FAIR1M dataset.

New in version 0.2.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)

Initialize a new FAIR1MDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • test_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to FAIR1M.
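Because val_split_pct and test_split_pct are fractions, the resulting subset sizes follow from simple arithmetic. The helper below is a hypothetical sketch. Rounding each held-out subset down with int() and giving the remainder to training is an assumption, and the library's exact rounding may differ.

```python
from typing import Tuple


def split_lengths(
    n: int, val_split_pct: float = 0.2, test_split_pct: float = 0.2
) -> Tuple[int, int, int]:
    """Return (train, val, test) sizes for fractional split arguments."""
    if not 0.0 <= val_split_pct + test_split_pct < 1.0:
        raise ValueError("val_split_pct + test_split_pct must be in [0, 1)")
    n_val = int(n * val_split_pct)    # held-out sizes are rounded down
    n_test = int(n * test_split_pct)
    return (n - n_val - n_test, n_val, n_test)  # remainder goes to training


print(split_lengths(1000))            # (600, 200, 200)
print(split_lengths(1000, 0.1, 0.2))  # (700, 100, 200)
```

The second call uses the SpaceNet1DataModule defaults listed further down this page.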

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

GID-15

class torchgeo.datamodules.GID15DataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the GID-15 dataset.

Uses the train/test splits from the dataset.

New in version 0.4.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new GID15DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (Tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to GID15.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Inria Aerial Image Labeling

class torchgeo.datamodules.InriaAerialImageLabelingDataModule(batch_size=64, patch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.1, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the InriaAerialImageLabeling dataset.

Uses the train/test splits from the dataset and further splits the train split into train/val splits.

New in version 0.3.

__init__(batch_size=64, patch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.1, **kwargs)

Initialize a new InriaAerialImageLabelingDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (Tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • test_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to InriaAerialImageLabeling.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

LandCover.ai

class torchgeo.datamodules.LandCoverAIDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LandCover.ai dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new LandCoverAIDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LandCoverAI.

LoveDA

class torchgeo.datamodules.LoveDADataModule(batch_size=32, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LoveDA dataset.

Uses the train/val/test splits from the dataset.

New in version 0.2.

__init__(batch_size=32, num_workers=0, **kwargs)

Initialize a new LoveDADataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LoveDA.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

NASA Marine Debris

class torchgeo.datamodules.NASAMarineDebrisDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the NASA Marine Debris dataset.

New in version 0.2.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)

Initialize a new NASAMarineDebrisDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • test_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to NASAMarineDebris.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

OSCD

class torchgeo.datamodules.OSCDDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the OSCD dataset.

Uses the train/test splits from the dataset and further splits the train split into train/val splits.

New in version 0.2.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new OSCDDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (Tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to OSCD.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Potsdam

class torchgeo.datamodules.Potsdam2DDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the Potsdam2D dataset.

Uses the train/test splits from the dataset.

New in version 0.2.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new Potsdam2DDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (Tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to Potsdam2D.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

RESISC45

class torchgeo.datamodules.RESISC45DataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the RESISC45 dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new RESISC45DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to RESISC45.

SEN12MS

class torchgeo.datamodules.SEN12MSDataModule(batch_size=64, num_workers=0, band_set='all', **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the SEN12MS dataset.

Implements 80/20 geographic train/val splits and uses the test split from the classification dataset definitions.

Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See https://arxiv.org/abs/2002.08254.

DFC2020_CLASS_MAPPING = tensor([ 0,  1,  1,  1,  1,  1,  2,  2,  3,  3,  4,  5,  6,  7,  6,  8,  9, 10])

Mapping from the IGBP class definitions to the DFC2020 classes, taken from the dataloader at https://github.com/lukasliebel/dfc2020_baseline.
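DFC2020_CLASS_MAPPING works as a lookup table: the value at position i is the simplified class assigned to IGBP class i. Below is a plain-Python sketch of applying it. remap_igbp is a hypothetical helper. With a PyTorch label tensor `mask` of integer (long) dtype, the same lookup can be written as DFC2020_CLASS_MAPPING[mask] using tensor indexing.

```python
from typing import List, Sequence

# Values copied from DFC2020_CLASS_MAPPING above; the position is the IGBP class index.
DFC2020_CLASS_MAPPING = [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]


def remap_igbp(labels: Sequence[int]) -> List[int]:
    """Look up each IGBP class index in the mapping table."""
    return [DFC2020_CLASS_MAPPING[label] for label in labels]


print(remap_igbp([1, 6, 13, 17]))  # [1, 2, 7, 10]
```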

__init__(batch_size=64, num_workers=0, band_set='all', **kwargs)

Initialize a new SEN12MSDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • band_set (str) – Subset of S1/S2 bands to use. Options are “all”, “s1”, “s2-all”, and “s2-reduced”, where “s2-reduced” includes B2, B3, B4, B8, B11, and B12.

  • **kwargs (Any) – Additional keyword arguments passed to SEN12MS.
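The band_set options can be pictured as named subsets of a stacked channel list. The sketch below assumes the two Sentinel-1 polarizations (VV, VH) come first, followed by the 13 Sentinel-2 bands in their standard order. That layout and the band_indices helper are assumptions for illustration, not a description of how SEN12MS stores its data.

```python
from typing import Dict, List

# Assumed band order: Sentinel-1 (VV, VH) followed by the 13 Sentinel-2 bands.
S1_BANDS = ["VV", "VH"]
S2_BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]
ALL_BANDS = S1_BANDS + S2_BANDS

BAND_SETS: Dict[str, List[str]] = {
    "all": ALL_BANDS,
    "s1": S1_BANDS,
    "s2-all": S2_BANDS,
    "s2-reduced": ["B2", "B3", "B4", "B8", "B11", "B12"],
}


def band_indices(band_set: str) -> List[int]:
    """Channel indices of the requested band set within ALL_BANDS."""
    if band_set not in BAND_SETS:
        raise ValueError(f"unknown band_set {band_set!r}")
    return [ALL_BANDS.index(band) for band in BAND_SETS[band_set]]


print(band_indices("s2-reduced"))  # [3, 4, 5, 9, 13, 14]
```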

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

Dict[str, Tensor]

So2Sat

class torchgeo.datamodules.So2SatDataModule(batch_size=64, num_workers=0, band_set='all', **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the So2Sat dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, band_set='all', **kwargs)

Initialize a new So2SatDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • band_set (str) – One of ‘all’, ‘s1’, or ‘s2’.

  • **kwargs (Any) – Additional keyword arguments passed to So2Sat.

setup(stage)

Set up datasets.

Called at the beginning of fit, validate, test, or predict. During distributed training, this method is called from every process across all the nodes. Setting state here is recommended.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

SpaceNet

class torchgeo.datamodules.SpaceNet1DataModule(batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the SpaceNet1 dataset.

Randomly splits into train/val/test.

New in version 0.4.

__init__(batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)

Initialize a new SpaceNet1DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • test_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to SpaceNet1.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

Dict[str, Tensor]

Tropical Cyclone

class torchgeo.datamodules.TropicalCycloneDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the NASA Cyclone dataset.

Implements 80/20 train/val splits grouped by hurricane storm ID, so all samples from a given storm fall in the same split.
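Splitting by storm ID rather than by individual sample keeps images of one storm from leaking between train and val. Below is a sketch of such a grouped split. split_by_storm is hypothetical. It applies the 80/20 ratio to the number of storms, which is one reasonable reading and may not match the library's exact procedure.

```python
import random
from typing import List, Sequence, Tuple


def split_by_storm(
    storm_ids: Sequence[str], train_fraction: float = 0.8, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Split sample indices so that every storm lands entirely in train or in val."""
    storms = sorted(set(storm_ids))  # sort first so the seeded shuffle is reproducible
    random.Random(seed).shuffle(storms)
    train_storms = set(storms[: int(len(storms) * train_fraction)])
    train = [i for i, storm in enumerate(storm_ids) if storm in train_storms]
    val = [i for i, storm in enumerate(storm_ids) if storm not in train_storms]
    return train, val


ids = ["abc", "abc", "abc", "def", "def", "ghi", "jkl", "mno"]
train_idx, val_idx = split_by_storm(ids)
# No storm appears on both sides of the split.
assert not {ids[i] for i in train_idx} & {ids[i] for i in val_idx}
```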

Changed in version 0.4: Class name changed from CycloneDataModule to TropicalCycloneDataModule to be consistent with TropicalCyclone dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new TropicalCycloneDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to TropicalCyclone.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

UC Merced

class torchgeo.datamodules.UCMercedDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the UC Merced dataset.

Uses random train/val/test splits.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new UCMercedDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to UCMerced.

USAVars

class torchgeo.datamodules.USAVarsDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the USAVars dataset.

Uses random train/val/test splits.

New in version 0.3.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new USAVarsDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to USAVars.

Vaihingen

class torchgeo.datamodules.Vaihingen2DDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the Vaihingen2D dataset.

Uses the train/test splits from the dataset.

New in version 0.2.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new Vaihingen2DDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (Tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to Vaihingen2D.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

xView2

class torchgeo.datamodules.XView2DataModule(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the xView2 dataset.

Uses the train/val/test splits from the dataset.

New in version 0.2.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)

Initialize a new XView2DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Fraction of the dataset (between 0 and 1) to use as a validation set.

  • **kwargs (Any) – Additional keyword arguments passed to XView2.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Base Classes

GeoDataModule

class torchgeo.datamodules.GeoDataModule(dataset_class, batch_size=1, patch_size=64, length=1000, num_workers=0, **kwargs)

Bases: LightningDataModule

Base class for data modules containing geospatial information.

New in version 0.4.

__init__(dataset_class, batch_size=1, patch_size=64, length=1000, num_workers=0, **kwargs)

Initialize a new GeoDataModule instance.

Parameters:
  • dataset_class (Type[GeoDataset]) – Class used to instantiate a new dataset.

  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | Tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to dataset_class.

prepare_data()

Download and prepare data.

During distributed training, this method is called only within a single process to avoid corrupted data. This method should not set state, since it is not called on every device; use setup() instead.

setup(stage)

Set up datasets and samplers.

Called at the beginning of fit, validate, test, or predict. During distributed training, this method is called from every process across all the nodes. Setting state here is recommended.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

train_dataloader()

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of data loaders specifying training samples.

Raises:

MisconfigurationException – If setup() does not define a ‘train_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

val_dataloader()

Implement one or more PyTorch DataLoaders for validation.

Returns:

A collection of data loaders specifying validation samples.

Raises:

MisconfigurationException – If setup() does not define a ‘val_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

test_dataloader()

Implement one or more PyTorch DataLoaders for testing.

Returns:

A collection of data loaders specifying testing samples.

Raises:

MisconfigurationException – If setup() does not define a ‘test_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

predict_dataloader()

Implement one or more PyTorch DataLoaders for prediction.

Returns:

A collection of data loaders specifying prediction samples.

Raises:

MisconfigurationException – If setup() does not define a ‘predict_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]
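The Raises entries above describe a guard: each *_dataloader method fails early if setup() did not create the matching dataset for the current stage. The toy class below sketches that contract in plain Python. ToyDataModule and its list-based "loader" are hypothetical, and the exception is a local stand-in for torchgeo.datamodules.MisconfigurationException so the sketch runs without torchgeo installed.

```python
from typing import Any, List, Optional


class MisconfigurationException(Exception):
    """Local stand-in so the sketch runs without torchgeo installed."""


class ToyDataModule:
    """Hypothetical module showing the documented 'dataset must exist' check."""

    def __init__(self, batch_size: int = 4) -> None:
        self.batch_size = batch_size
        self.train_dataset: Optional[List[Any]] = None

    def setup(self, stage: str) -> None:
        # Only the 'fit' stage creates a training dataset.
        if stage == "fit":
            self.train_dataset = list(range(10))

    def train_dataloader(self) -> List[List[Any]]:
        if self.train_dataset is None:
            raise MisconfigurationException(
                f"{type(self).__name__}.setup does not define a 'train_dataset'"
            )
        # Stand-in for a DataLoader: consecutive mini-batches of batch_size items.
        data = self.train_dataset
        return [data[i : i + self.batch_size] for i in range(0, len(data), self.batch_size)]


module = ToyDataModule()
module.setup("fit")
print([len(batch) for batch in module.train_dataloader()])  # [4, 4, 2]
```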

transfer_batch_to_device(batch, device, dataloader_idx)

Transfer batch to device.

Defines how custom data types are moved to the target device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be transferred to a new device.

  • device (device) – The target device as defined in PyTorch.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A reference to the data on the new device.

Return type:

Dict[str, Tensor]

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

Dict[str, Tensor]
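One way to structure this hook is to pick an augmentation according to the current phase and fall back to a shared default. The sketch below uses the hypothetical names aug, train_aug, and val_aug and a plain dict batch of Python lists. Real batches hold tensors already moved to the target device, and the library's attribute names and selection logic may differ.

```python
from typing import Callable, Dict, List, Optional

Batch = Dict[str, List[float]]
Augmentation = Callable[[Batch], Batch]


def pick_augmentation(
    phase: str,
    aug: Optional[Augmentation] = None,
    train_aug: Optional[Augmentation] = None,
    val_aug: Optional[Augmentation] = None,
) -> Optional[Augmentation]:
    """Prefer a phase-specific augmentation; otherwise fall back to the shared `aug`."""
    specific = {"train": train_aug, "val": val_aug}.get(phase)
    return specific if specific is not None else aug


def scale_image(batch: Batch) -> Batch:
    """Toy augmentation: rescale 'image' values from [0, 255] to [0, 1]."""
    return {**batch, "image": [value / 255 for value in batch["image"]]}


batch: Batch = {"image": [0, 51, 255], "mask": [0, 1, 1]}
augmentation = pick_augmentation("val", aug=scale_image)
if augmentation is not None:
    batch = augmentation(batch)
print(batch["image"])  # [0.0, 0.2, 1.0]
```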

plot(*args, **kwargs)

Run the plot method of the validation dataset if one exists.

Should only be called during the ‘fit’ or ‘validate’ stages, as val_dataset may not exist during other stages.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

NonGeoDataModule

class torchgeo.datamodules.NonGeoDataModule(dataset_class, batch_size=1, num_workers=0, **kwargs)

Bases: LightningDataModule

Base class for data modules lacking geospatial information.

New in version 0.4.

__init__(dataset_class, batch_size=1, num_workers=0, **kwargs)

Initialize a new NonGeoDataModule instance.

Parameters:
  • dataset_class (Type[NonGeoDataset]) – Class used to instantiate a new dataset.

  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to dataset_class.

prepare_data()

Download and prepare data.

During distributed training, this method is called only within a single process to avoid corrupted data. This method should not set state, since it is not called on every device; use setup() instead.

setup(stage)

Set up datasets.

Called at the beginning of fit, validate, test, or predict. During distributed training, this method is called from every process across all the nodes. Setting state here is recommended.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.
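For datasets that already ship train/val/test splits, setup() can be pictured as instantiating dataset_class once for each split the current stage needs, forwarding the stored **kwargs. The sketch below assumes the dataset accepts a split keyword. FakeDataset and SketchNonGeoDataModule are hypothetical. Data modules that carve out splits by fraction (for example FAIR1M or OSCD) define their own setup().

```python
from typing import Any, Dict, Type


class FakeDataset:
    """Hypothetical dataset that only records the split and root it was built with."""

    def __init__(self, split: str = "train", root: str = "data") -> None:
        self.split = split
        self.root = root


class SketchNonGeoDataModule:
    """Sketch: build one dataset per split, forwarding **kwargs to dataset_class."""

    def __init__(self, dataset_class: Type[Any], **kwargs: Any) -> None:
        self.dataset_class = dataset_class
        self.kwargs: Dict[str, Any] = kwargs

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.train_dataset = self.dataset_class(split="train", **self.kwargs)
        if stage in ("fit", "validate"):
            self.val_dataset = self.dataset_class(split="val", **self.kwargs)
        if stage == "test":
            self.test_dataset = self.dataset_class(split="test", **self.kwargs)


dm = SketchNonGeoDataModule(FakeDataset, root="/tmp/data")
dm.setup("fit")
print(dm.train_dataset.split, dm.val_dataset.split, dm.val_dataset.root)  # train val /tmp/data
print(hasattr(dm, "test_dataset"))  # False
```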

train_dataloader()

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of data loaders specifying training samples.

Raises:

MisconfigurationException – If setup() does not define a ‘train_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

val_dataloader()

Implement one or more PyTorch DataLoaders for validation.

Returns:

A collection of data loaders specifying validation samples.

Raises:

MisconfigurationException – If setup() does not define a ‘val_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

test_dataloader()

Implement one or more PyTorch DataLoaders for testing.

Returns:

A collection of data loaders specifying testing samples.

Raises:

MisconfigurationException – If setup() does not define a ‘test_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

predict_dataloader()

Implement one or more PyTorch DataLoaders for prediction.

Returns:

A collection of data loaders specifying prediction samples.

Raises:

MisconfigurationException – If setup() does not define a ‘predict_dataset’.

Return type:

DataLoader[Dict[str, Tensor]]

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (Dict[str, Tensor]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

Dict[str, Tensor]

plot(*args, **kwargs)

Run the plot method of the validation dataset if one exists.

Should only be called during the ‘fit’ or ‘validate’ stages, as val_dataset may not exist during other stages.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

Utilities

class torchgeo.datamodules.MisconfigurationException

Bases: Exception

Exception used to inform users of misuse with Lightning.
