Shortcuts

Source code for torchgeo.datamodules.cyclone

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Tropical Cyclone Wind Estimation Competition datamodule."""

from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import DataLoader, Subset

from ..datasets import TropicalCycloneWindEstimation

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class CycloneDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the NASA Cyclone dataset.

    Implements 80/20 train/val splits based on hurricane storm ids.
    See :func:`setup` for more details.
    """

    def __init__(
        self,
        root_dir: str,
        seed: int,
        batch_size: int = 64,
        num_workers: int = 0,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the
                TropicalCycloneWindEstimation Datasets classes
            seed: The seed value to use when doing the sklearn based
                GroupShuffleSplit
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            api_key: The RadiantEarth MLHub API key to use if the dataset
                needs to be downloaded
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.api_key = api_key

    def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: dictionary containing image and target

        Returns:
            preprocessed sample
        """
        sample["image"] = sample["image"] / 255.0  # scale to [0,1]
        sample["image"] = (
            sample["image"].unsqueeze(0).repeat(3, 1, 1)
        )  # convert to 3 channel
        sample["label"] = torch.as_tensor(  # type: ignore[attr-defined]
            sample["label"]
        ).float()
        return sample

    def prepare_data(self) -> None:
        """Initialize the main ``Dataset`` objects for use in :func:`setup`.

        This includes optionally downloading the dataset. This is done once per
        node, while :func:`setup` is done once per GPU.
        """
        # Download only when an API key was supplied; otherwise assume the
        # data already exists on disk.
        TropicalCycloneWindEstimation(
            self.root_dir,
            split="train",
            transforms=self.custom_transform,
            download=self.api_key is not None,
            api_key=self.api_key,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train/val/test splits based on the original Dataset objects.

        The splits should be done here vs. in :func:`__init__` per the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup.

        We split samples between train/val by the ``storm_id`` property. I.e. all
        samples with the same ``storm_id`` value will be either in the train or the
        val split. This is important to test one type of generalizability -- given
        a new storm, can we predict its windspeed. The test set, however, contains
        *some* storms from the training set (specifically, the latter parts of the
        storms) as well as some novel storms.

        Args:
            stage: stage to set up
        """
        self.all_train_dataset = TropicalCycloneWindEstimation(
            self.root_dir,
            split="train",
            transforms=self.custom_transform,
            download=False,
        )

        self.all_test_dataset = TropicalCycloneWindEstimation(
            self.root_dir,
            split="test",
            transforms=self.custom_transform,
            download=False,
        )

        # Extract the storm id from each item's href so samples from the same
        # storm are grouped together and never straddle the train/val split.
        storm_ids = []
        for item in self.all_train_dataset.collection:
            storm_id = item["href"].split("/")[0].split("_")[-2]
            storm_ids.append(storm_id)

        # GroupShuffleSplit yields multiple splits; we only consume the first
        # one, which is fully determined by ``self.seed``.
        train_indices, val_indices = next(
            GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split(
                storm_ids, groups=storm_ids
            )
        )

        self.train_dataset = Subset(self.all_train_dataset, train_indices)
        self.val_dataset = Subset(self.all_train_dataset, val_indices)
        self.test_dataset = Subset(
            self.all_test_dataset, range(len(self.all_test_dataset))
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

© Copyright 2021, Microsoft Corporation. Revision e1285e6c.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources