Shortcuts

Source code for torchgeo.datamodules.nasa_marine_debris

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""NASA Marine Debris datamodule."""

from typing import Any, Dict, List, Optional

import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import NASAMarineDebris
from .utils import dataset_split

# Override the module path that DataLoader reports for itself so that it reads
# as the public ``torch.utils.data`` package.
# NOTE(review): presumably a workaround for the upstream problem tracked in the
# links below -- confirm it is still needed before removing.
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Collate object detection samples whose number of boxes may differ.

    The images are stacked into a single tensor along a new leading dimension,
    whereas the boxes are kept as a plain list with one entry per sample.

    Args:
        batch: list of sample dicts returned by the dataset, each providing an
            ``"image"`` and a ``"boxes"`` entry

    Returns:
        dict holding the stacked ``"image"`` tensor and the ``"boxes"`` list
    """
    # Stack first, then gather boxes: same evaluation order as a plain
    # key-by-key construction of the output dict.
    stacked_images = torch.stack([item["image"] for item in batch])
    per_sample_boxes = [item["boxes"] for item in batch]
    return {"image": stacked_images, "boxes": per_sample_boxes}


class NASAMarineDebrisDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Marine Debris dataset."""

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Dataset class
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: Extra keyword arguments; accepted but not used here
        """
        super().__init__()  # type: ignore[no-untyped-call]

        # Where the dataset lives on disk.
        self.root_dir = root_dir

        # Settings shared by every DataLoader this module creates.
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Fractions handed to ``dataset_split`` in ``setup``.
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        The ``"image"`` entry is cast to float and divided by 255
        (presumably mapping 8-bit pixel values into [0, 1]); all other
        entries are left untouched.

        Args:
            sample: input image dictionary

        Returns:
            the same dictionary with its ``"image"`` entry rescaled
        """
        image = sample["image"].float()
        image /= 255.0  # in-place on the float tensor
        sample["image"] = image
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        Instantiates the dataset with ``checksum=False`` and discards the
        instance.

        This method is only called once per run.
        """
        NASAMarineDebris(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        Builds the dataset with :meth:`preprocess` as its transform and splits
        it into train, validation and test subsets.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (not used by this implementation)
        """
        dataset = NASAMarineDebris(
            self.root_dir, transforms=Compose([self.preprocess])
        )
        splits = dataset_split(
            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def _make_loader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Create a DataLoader over ``dataset`` using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader (shuffled)
        """
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader (not shuffled)
        """
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader (not shuffled)
        """
        return self._make_loader(self.test_dataset, shuffle=False)

© Copyright 2021, Microsoft Corporation. Revision e1285e6c.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources