Source code for torchgeo.datamodules.cowc
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""COWC datamodule."""
from typing import Any, Dict, Optional
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch import Generator
from torch.utils.data import DataLoader, random_split
from ..datasets import COWCCounting
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
class COWCCountingDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the COWC Counting dataset."""
[docs] def __init__(
self,
root_dir: str,
seed: int,
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for COWC Counting based DataLoaders.
Args:
root_dir: The ``root`` arugment to pass to the COWCCounting Dataset class
seed: The seed value to use when doing the dataset random_split
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
"""
super().__init__()
self.root_dir = root_dir
self.seed = seed
self.batch_size = batch_size
self.num_workers = num_workers
[docs] def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: dictionary containing image and target
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0 # scale to [0, 1]
if "label" in sample:
sample["label"] = sample["label"].float()
return sample
[docs] def prepare_data(self) -> None:
"""Initialize the main ``Dataset`` objects for use in :func:`setup`.
This includes optionally downloading the dataset. This is done once per node,
while :func:`setup` is done once per GPU.
"""
COWCCounting(self.root_dir, download=False)
[docs] def setup(self, stage: Optional[str] = None) -> None:
"""Create the train/val/test splits based on the original Dataset objects.
The splits should be done here vs. in :func:`__init__` per the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.
Args:
stage: stage to set up
"""
train_val_dataset = COWCCounting(
self.root_dir, split="train", transforms=self.preprocess
)
self.test_dataset = COWCCounting(
self.root_dir, split="test", transforms=self.preprocess
)
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
[len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)],
generator=Generator().manual_seed(self.seed),
)
[docs] def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
[docs] def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
[docs] def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
[docs] def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.COWC.plot`.
.. versionadded:: 0.2
"""
return self.test_dataset.plot(*args, **kwargs)