Source code for torchgeo.datamodules.eurosat
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""EuroSAT datamodule."""
from typing import Any, Dict, Optional
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from ..datasets import EuroSAT
class EuroSATDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the EuroSAT dataset.
Uses the train/val/test splits from the dataset.
.. versionadded:: 0.2
"""
band_means = torch.tensor( # type: ignore[attr-defined]
[
1354.40546513,
1118.24399958,
1042.92983953,
947.62620298,
1199.47283961,
1999.79090914,
2369.22292565,
2296.82608323,
732.08340178,
12.11327804,
1819.01027855,
1118.92391149,
2594.14080798,
]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
[
245.71762908,
333.00778264,
395.09249139,
593.75055589,
566.4170017,
861.18399006,
1086.63139075,
1117.98170791,
404.91978886,
4.77584468,
1002.58768311,
761.30323499,
1231.58581042,
]
)
[docs] def __init__(
self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a LightningDataModule for EuroSAT based DataLoaders.
Args:
root_dir: The ``root`` arugment to pass to the EuroSAT Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.norm = Normalize(self.band_means, self.band_stds)
[docs] def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] = self.norm(sample["image"])
return sample
[docs] def prepare_data(self) -> None:
"""Make sure that the dataset is downloaded.
This method is only called once per run.
"""
EuroSAT(self.root_dir)
[docs] def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
self.train_dataset = EuroSAT(self.root_dir, "train", transforms=transforms)
self.val_dataset = EuroSAT(self.root_dir, "val", transforms=transforms)
self.test_dataset = EuroSAT(self.root_dir, "test", transforms=transforms)
[docs] def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
[docs] def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
[docs] def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
[docs] def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.EuroSAT.plot`."""
return self.val_dataset.plot(*args, **kwargs)