# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""FAIR1M datamodule."""
from typing import Any, Dict, List, Optional
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from ..datasets import FAIR1M
from .utils import dataset_split
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Custom object detection collate fn to handle variable number of boxes.

    Args:
        batch: list of sample dicts returned by the dataset

    Returns:
        batch dict output
    """
    output: Dict[str, Any] = {}
    output["image"] = torch.stack([sample["image"] for sample in batch])
    output["boxes"] = [sample["boxes"] for sample in batch]
    return output
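
# A minimal sketch of the collate behavior (tensor shapes are illustrative
# only; FAIR1M's actual box format may differ):
#
#     batch = [
#         {"image": torch.rand(3, 512, 512), "boxes": torch.rand(2, 4)},
#         {"image": torch.rand(3, 512, 512), "boxes": torch.rand(7, 4)},
#     ]
#     out = collate_fn(batch)
#     out["image"].shape                   # torch.Size([2, 3, 512, 512])
#     [b.shape[0] for b in out["boxes"]]   # [2, 7]: per-image box counts kept
#
# Images stack because they share a shape; boxes stay a list because the
# default collate would fail on tensors with differing first dimensions.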


class FAIR1MDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the FAIR1M dataset.

    .. versionadded:: 0.2
    """
    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for FAIR1M based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the FAIR1M Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct
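
    # A hypothetical construction (the path and argument values are
    # illustrative, not defaults the library prescribes):
    #
    #     dm = FAIR1MDataModule(root_dir="data/fair1m", batch_size=16, num_workers=4)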

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"].float()
        sample["image"] /= 255.0
        return sample
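
    # The scaling above casts uint8 images to float and rescales [0, 255] to
    # [0, 1]; for example (values illustrative, ``dm`` as constructed above):
    #
    #     sample = {"image": torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)}
    #     dm.preprocess(sample)["image"]  # tensor([[[0.0000, 0.5020, 1.0000]]])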

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        dataset = FAIR1M(self.root_dir, transforms=self.preprocess)
        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
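
    # With the defaults val_split_pct=0.2 and test_split_pct=0.2, dataset_split
    # partitions the full dataset into roughly 60% train, 20% val, 20% test.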

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            collate_fn=collate_fn,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_fn,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_fn,
        )
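

# A minimal end-to-end sketch of using this datamodule with a Lightning
# Trainer. ``ObjectDetectionTask`` is hypothetical (not part of this module):
# any LightningModule whose steps accept batches of the form
# {"image": Tensor, "boxes": List[Tensor]} would slot in here.
#
#     dm = FAIR1MDataModule(root_dir="data/fair1m", batch_size=8, num_workers=4)
#     task = ObjectDetectionTask(...)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(task, datamodule=dm)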