Shortcuts

Source code for torchgeo.datamodules.treesatai

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TreeSatAI datamodules."""

from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor
from torch.utils.data import random_split

from ..datasets import TreeSatAI
from ..samplers.utils import _to_tuple
from .geo import NonGeoDataModule

# https://git.tu-berlin.de/rsim/treesat_benchmark/-/blob/master/configs/multimodal/AllModes_Xformer_ResnetScratch_v8.json
means = {
    'aerial': [
        151.26809261440323,
        93.1159469148246,
        85.05016794624635,
        81.0471576353153,
    ],
    's1': [-6.933713050794077, -12.628564056094067, 0.47448312147709354],
    's2': [
        231.43385024546893,
        376.94788434611434,
        241.03688288984037,
        2809.8421354087955,
        616.5578221193639,
        2104.3826773960823,
        2695.083864757169,
        2969.868417923599,
        1306.0814241837832,
        587.0608264363341,
        249.1888624097736,
        2950.2294375352285,
    ],
}
stds = {
    'aerial': [
        48.70879149145466,
        33.59622314610158,
        28.000497087051126,
        33.683983599997724,
    ],
    's1': [87.8762246957811, 47.03070478433704, 1.297291303623673],
    's2': [
        123.16515044781909,
        139.78991338362886,
        140.6154081184225,
        786.4508872594147,
        202.51268536579394,
        530.7255451201194,
        710.2650071967689,
        777.4421400779165,
        424.30312334282684,
        247.21468849049668,
        122.80062680549261,
        702.7404237034002,
    ],
}


[docs]class TreeSatAIDataModule(NonGeoDataModule): """LightningDataModule implementation for the TreeSatAI dataset. .. versionadded:: 0.7 """
[docs] def __init__( self, batch_size: int = 64, patch_size: int | tuple[int, int] = 304, num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a new TreeSatAIDataModule instance. Args: batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.TreeSatAI`. """ super().__init__(TreeSatAI, batch_size, num_workers, **kwargs) self.patch_size = _to_tuple(patch_size) self.sensors = kwargs.get('sensors', TreeSatAI.all_sensors) self.train_aug = K.AugmentationSequential( K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), K.Resize(self.patch_size), data_keys=None, keepdim=True, ) self.aug = K.AugmentationSequential( K.Resize(self.patch_size), data_keys=None, keepdim=True )
[docs] def setup(self, stage: str) -> None: """Set up datasets. Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ # Convert 90-10 train-test split to 80-10-10 train-val-test split train_val_dataset = TreeSatAI(split='train', **self.kwargs) self.test_dataset = TreeSatAI(split='test', **self.kwargs) generator = torch.Generator().manual_seed(0) self.train_dataset, self.val_dataset = random_split( train_val_dataset, [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], generator=generator, )
[docs] def on_after_batch_transfer( self, batch: dict[str, Tensor], dataloader_idx: int ) -> dict[str, Tensor]: """Apply batch augmentations to the batch after it is transferred to the device. Args: batch: A batch of data that needs to be altered or augmented. dataloader_idx: The index of the dataloader to which the batch belongs. Returns: A batch of data. """ batch = super().on_after_batch_transfer(batch, dataloader_idx) images = [] for sensor in self.sensors: aug = K.Normalize(mean=means[sensor], std=stds[sensor], keepdim=True) batch[f'image_{sensor}'] = aug(batch[f'image_{sensor}']) images.append(batch[f'image_{sensor}']) batch['image'] = torch.cat(images, dim=1) return batch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources