Shortcuts

Source code for torchgeo.datamodules.bigearthnet

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""BigEarthNet datamodule."""

from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import BigEarthNet

# Workaround for https://github.com/pytorch/pytorch/issues/60979
# (see also https://github.com/pytorch/pytorch/pull/61045): make DataLoader
# report the public ``torch.utils.data`` namespace it is imported from as its
# module. NOTE(review): presumably for documentation cross-referencing —
# confirm against the linked issue before removing.
_PUBLIC_DATALOADER_MODULE = "torch.utils.data"
DataLoader.__module__ = _PUBLIC_DATALOADER_MODULE


class BigEarthNetDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the BigEarthNet dataset.

    Uses the train/val/test splits from the dataset.
    """

    # Band order of every statistics tensor below:
    # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12)
    # The first two entries are the Sentinel-1 bands, the remaining twelve
    # are the Sentinel-2 bands.

    # min/max band statistics computed on 100k random samples
    band_mins_raw = torch.tensor(  # type: ignore[attr-defined]
        [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    )
    band_maxs_raw = torch.tensor(  # type: ignore[attr-defined]
        [
            31.0,
            35.0,
            18556.0,
            20528.0,
            18976.0,
            17874.0,
            16611.0,
            16512.0,
            16394.0,
            16672.0,
            16141.0,
            16097.0,
            15336.0,
            15203.0,
        ]
    )

    # min/max band statistics computed by clipping the above samples to their
    # [2, 98] percentile range; these are the values used for normalization
    band_mins = torch.tensor(  # type: ignore[attr-defined]
        [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
    )
    band_maxs = torch.tensor(  # type: ignore[attr-defined]
        [
            6.0,
            16.0,
            9859.0,
            12872.0,
            13163.0,
            14445.0,
            12477.0,
            12563.0,
            12289.0,
            15596.0,
            12183.0,
            9458.0,
            5897.0,
            5544.0,
        ]
    )

    # Which entries of the statistics tensors belong to each ``bands`` option:
    # Sentinel-1 is the first two entries, Sentinel-2 everything after them.
    _band_slices: Dict[str, slice] = {
        "all": slice(None),
        "s1": slice(None, 2),
        "s2": slice(2, None),
    }

    def __init__(
        self,
        root_dir: str,
        bands: str = "all",
        num_classes: int = 19,
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for BigEarthNet based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the BigEarthNet Dataset classes
            bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all}
            num_classes: number of classes to load in target. one of {19, 43}
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            **kwargs: Additional keyword arguments; accepted and ignored

        Raises:
            ValueError: if ``bands`` is not one of {s1, s2, all}
        """
        super().__init__()  # type: ignore[no-untyped-call]

        # Reject unknown values up front instead of silently normalizing them
        # with the Sentinel-2 statistics.
        if bands not in self._band_slices:
            raise ValueError(
                f"bands must be one of {sorted(self._band_slices)}, got {bands!r}"
            )

        self.root_dir = root_dir
        self.bands = bands
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Reshape the selected (C,) statistics to (C, 1, 1) so that they
        # broadcast over the last two dimensions of the image in ``preprocess``.
        band_slice = self._band_slices[bands]
        self.mins = self.band_mins[band_slice, None, None]
        self.maxs = self.band_maxs[band_slice, None, None]

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Casts ``sample["image"]`` to float, min/max normalizes each band with
        the statistics selected in ``__init__`` and clips the result to [0, 1].

        Args:
            sample: a sample dict containing an ``"image"`` tensor; the dict is
                modified in place

        Returns:
            the same ``sample`` dict, holding the normalized image
        """
        sample["image"] = sample["image"].float()
        sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins)
        # Pixels outside the percentile-clipped [min, max] range fall outside
        # [0, 1] after normalization; clamp them back into range.
        sample["image"] = torch.clip(  # type: ignore[attr-defined]
            sample["image"], min=0.0, max=1.0
        )
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run. It instantiates the training
        split once, with ``checksum=False``.
        """
        # NOTE(review): no ``download`` flag is passed here, so whether this
        # actually fetches data depends on BigEarthNet's defaults — confirm.
        BigEarthNet(self.root_dir, split="train", bands=self.bands, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run. All three splits are
        created regardless of ``stage``, each using ``preprocess`` as its
        transform.

        Args:
            stage: the Lightning stage being set up (unused)
        """
        transforms = Compose([self.preprocess])
        self.train_dataset = BigEarthNet(
            self.root_dir,
            split="train",
            bands=self.bands,
            num_classes=self.num_classes,
            transforms=transforms,
        )
        self.val_dataset = BigEarthNet(
            self.root_dir,
            split="val",
            bands=self.bands,
            num_classes=self.num_classes,
            transforms=transforms,
        )
        self.test_dataset = BigEarthNet(
            self.root_dir,
            split="test",
            bands=self.bands,
            num_classes=self.num_classes,
            transforms=transforms,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Requires ``setup`` to have been called. Samples are shuffled.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Requires ``setup`` to have been called. Samples are not shuffled.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Requires ``setup`` to have been called. Samples are not shuffled.
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`torchgeo.datasets.BigEarthNet.plot`.

        Delegates to the validation dataset, so ``setup`` must have been
        called first.
        """
        return self.val_dataset.plot(*args, **kwargs)

© Copyright 2021, Microsoft Corporation. Revision e1285e6c.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources