
Source code for torchgeo.datamodules.nasa_marine_debris

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""NASA Marine Debris datamodule."""

from typing import Any, Dict, List, Optional

import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import NASAMarineDebris
from .utils import dataset_split

# Point ``__module__`` at DataLoader's public import path so documentation
# tools (e.g. Sphinx) resolve it as ``torch.utils.data.DataLoader``; an empty
# string leaves the class with no usable module path.
DataLoader.__module__ = "torch.utils.data"

def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
    """Custom object detection collate fn to handle variable boxes.

    Images are stacked into a single batch tensor, while the per-sample
    ``boxes`` tensors are kept in a list because each sample may contain a
    different number of boxes and therefore cannot be stacked.

    Args:
        batch: list of sample dicts returned by the dataset

    Returns:
        batch dict output with keys ``"image"`` (stacked tensor) and
        ``"boxes"`` (list of per-sample tensors)
    """
    output: Dict[str, Any] = {}
    # Images share a shape within a batch, so they can be stacked along a new
    # leading batch dimension.
    output["image"] = torch.stack([sample["image"] for sample in batch])
    # Box counts vary per sample; keep them as a plain list.
    output["boxes"] = [sample["boxes"] for sample in batch]
    return output

class NASAMarineDebrisDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Marine Debris dataset."""

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split_pct: float = 0.2,
        test_split_pct: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the Dataset class
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            val_split_pct: What percentage of the dataset to use as a validation set
            test_split_pct: What percentage of the dataset to use as a test set
            **kwargs: Accepted for interface compatibility; not used by this class
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Converts ``sample["image"]`` to float and scales it by ``1 / 255``.

        Args:
            sample: input image dictionary

        Returns:
            preprocessed sample
        """
        # ``.float()`` returns the *same* tensor when it is already float32, so
        # an in-place ``/=`` would silently mutate the caller's tensor. Divide
        # out-of-place to always produce a fresh tensor.
        sample["image"] = sample["image"].float() / 255.0
        return sample

    def prepare_data(self) -> None:
        """Make sure that the dataset is downloaded.

        This method is only called once per run.
        """
        # NOTE(review): no download flag is passed here; confirm whether
        # NASAMarineDebris downloads by default or only verifies local files.
        NASAMarineDebris(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        transforms = Compose([self.preprocess])
        dataset = NASAMarineDebris(self.root_dir, transforms=transforms)
        # Split the single dataset into train/val/test subsets using the
        # configured percentages.
        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            # Custom collate keeps variable-length box tensors in a list.
            collate_fn=collate_fn,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_fn,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=collate_fn,
        )

© Copyright 2021, Microsoft Corporation. Revision e1285e6c.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.2.0
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources