Shortcuts

Source code for torchgeo.datamodules.ftw

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""FTW datamodule."""

from typing import Any

import kornia.augmentation as K
import torch

from ..datasets import FieldsOfTheWorld
from .geo import NonGeoDataModule


class FieldsOfTheWorldDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the FTW dataset.

    .. versionadded:: 0.7
    """

    # Normalization statistics fed to ``K.Normalize`` in both augmentation
    # pipelines built in ``__init__``. These are one-element tensors;
    # presumably broadcast across all input channels -- TODO confirm.
    mean = torch.tensor([0])
    std = torch.tensor([3000])

    def __init__(
        self,
        train_countries: list[str] = ['austria'],
        val_countries: list[str] = ['austria'],
        test_countries: list[str] = ['austria'],
        batch_size: int = 64,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new FieldsOfTheWorldDataModule instance.

        Args:
            train_countries: List of countries to use for training.
            val_countries: List of countries to use for validation.
            test_countries: List of countries to use for testing.
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.FieldsOfTheWorld`.

        Raises:
            AssertionError: If 'countries' are specified in kwargs
        """
        # Raise explicitly rather than using ``assert`` so the check still
        # runs under ``python -O``; the exception type is kept unchanged for
        # callers that already catch AssertionError.
        if 'countries' in kwargs:
            raise AssertionError(
                "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs"
            )
        super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs)

        # Copy the country lists so instances never alias the shared default
        # list objects (or a caller's list); mutating one instance's list can
        # then not leak into other instances.
        self.train_countries = list(train_countries)
        self.val_countries = list(val_countries)
        self.test_countries = list(test_countries)

        # Training pipeline: normalize, then random geometric and sharpness
        # augmentations, each applied with probability 0.5.
        self.train_aug = K.AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomRotation(p=0.5, degrees=90),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomSharpness(p=0.5),
            data_keys=None,
            keepdim=True,
        )
        # Normalization-only pipeline; presumably used by the base class for
        # non-training stages -- verify against NonGeoDataModule.
        self.aug = K.AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Only the splits needed by ``stage`` are constructed: 'fit' builds the
        train and val datasets, 'validate' builds only the val dataset, and
        'test' builds only the test dataset.

        Args:
            stage: Either 'fit', 'validate', or 'test'.
        """
        # ``self.kwargs`` is not assigned in this class; presumably stored by
        # the NonGeoDataModule base constructor -- TODO confirm.
        if stage == 'fit':
            self.train_dataset = FieldsOfTheWorld(
                split='train', countries=self.train_countries, **self.kwargs
            )
        if stage in ('fit', 'validate'):
            self.val_dataset = FieldsOfTheWorld(
                split='val', countries=self.val_countries, **self.kwargs
            )
        if stage == 'test':
            self.test_dataset = FieldsOfTheWorld(
                split='test', countries=self.test_countries, **self.kwargs
            )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources