Shortcuts

Source code for torchgeo.transforms.transforms

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo transforms."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.geometry import crop_by_indices
from torch import Tensor
from torch.nn.modules import Module


# TODO: contribute these to Kornia and delete this file
class AugmentationSequential(Module):
    """Wrapper around kornia AugmentationSequential to handle input dicts.

    .. deprecated:: 0.4
       Use :class:`kornia.augmentation.container.AugmentationSequential` instead.
    """

    def __init__(
        self,
        *args: Union[K.base._AugmentationBase, K.ImageSequential],
        data_keys: list[str],
        **kwargs: Any,
    ) -> None:
        """Initialize a new augmentation sequential instance.

        Args:
            *args: Sequence of kornia augmentations
            data_keys: List of inputs to augment (e.g., ["image", "mask", "boxes"])
            **kwargs: Keyword arguments passed to ``K.AugmentationSequential``

        .. versionadded:: 0.5
           The ``**kwargs`` parameter.
        """
        super().__init__()
        self.data_keys = data_keys

        def to_kornia_name(name: str) -> str:
            # TorchGeo's "image"/"boxes" are called "input"/"bbox" by Kornia;
            # any other key is forwarded unchanged.
            if name == "image":
                return "input"
            if name == "boxes":
                return "bbox"
            return name

        kornia_keys = [to_kornia_name(name) for name in data_keys]
        self.augs = K.AugmentationSequential(*args, data_keys=kornia_keys, **kwargs)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the wrapped Kornia pipeline over the entries named in ``data_keys``.

        The selected entries are cast to float, passed positionally to the
        Kornia pipeline, written back into ``batch``, and cast back to the
        dtype they had on entry. ``batch`` is modified in place and returned.

        Args:
            batch: the input

        Returns:
            the augmented input
        """
        # Kornia only works on floating point tensors; remember each original
        # dtype so it can be restored once augmentation is done.
        original_dtypes: dict[str, torch.dtype] = {}
        for name in self.data_keys:
            tensor = batch[name]
            original_dtypes[name] = tensor.dtype
            batch[name] = tensor.float()

        # Kornia expects masks to carry an explicit channel dimension.
        if "mask" in batch and batch["mask"].ndim == 3:
            batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")

        augmented: Union[Tensor, list[Tensor]] = self.augs(
            *(batch[name] for name in self.data_keys)
        )
        # A single data key yields a bare tensor rather than a list.
        if not isinstance(augmented, list):
            augmented = [augmented]
        batch.update(zip(self.data_keys, augmented))

        # Restore the dtypes recorded before the float conversion.
        for name, dtype in original_dtypes.items():
            batch[name] = batch[name].to(dtype)

        # Torchmetrics cannot consume masks that still have a channel dimension.
        if "mask" in batch and batch["mask"].shape[1] == 1:
            batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")

        return batch
class _RandomNCrop(K.GeometricAugmentationBase2D):
    """Take N random crops of a tensor."""

    def __init__(self, size: tuple[int, int], num: int) -> None:
        """Initialize a new _RandomNCrop instance.

        Args:
            size: desired output size (out_h, out_w) of the crop
            num: number of crops to take
        """
        super().__init__(p=1)
        self._param_generator: _NCropGenerator = _NCropGenerator(size, num)
        self.flags = dict(size=size, num=num)

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Compute the transformation.

        This always returns the identity matrix for ``input``; the cropping
        itself happens in :meth:`apply_transform`.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation
        """
        identity: Tensor = self.identity_matrix(input)
        return identity

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the transform.

        Crop ``input`` once per generated set of source indices and
        concatenate the crops along the first dimension.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters
            transform: the geometric transformation tensor (unused)

        Returns:
            the augmented input
        """
        crops = [
            crop_by_indices(input, params["src"][idx], flags["size"])
            for idx in range(flags["num"])
        ]
        return torch.cat(crops)


class _NCropGenerator(K.random_generator.CropGenerator):
    """Generate N random crops."""

    def __init__(self, size: Union[tuple[int, int], Tensor], num: int) -> None:
        """Initialize a new _NCropGenerator instance.

        Args:
            size: desired output size (out_h, out_w) of the crop
            num: number of crops to generate
        """
        super().__init__(size)
        self.num = num

    def forward(
        self, batch_shape: tuple[int, ...], same_on_batch: bool = False
    ) -> dict[str, Tensor]:
        """Generate the crops.

        The parent crop generator is sampled ``self.num`` times. The "src" and
        "dst" entries are stacked across samples; the size entries are taken
        from the first sample.

        Args:
            batch_shape: input size (b, c?, in_h, in_w)
            same_on_batch: apply the same transformation across the batch

        Returns:
            the randomly generated parameters
        """
        # Bind the parent method up front: zero-argument super() does not work
        # inside a comprehension's own scope.
        sample_once = super().forward
        samples = [sample_once(batch_shape, same_on_batch) for _ in range(self.num)]

        def stack_field(field: str) -> Tensor:
            return torch.stack([sample[field] for sample in samples])

        return {
            "src": stack_field("src"),
            "dst": stack_field("dst"),
            "input_size": samples[0]["input_size"],
            "output_size": samples[0]["output_size"],
        }

© Copyright 2021, Microsoft Corporation. Revision b9653beb.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
v0.5.2
v0.5.1
v0.5.0
v0.4.1
v0.4.0
v0.3.1
v0.3.0
v0.2.1
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources