Shortcuts

Source code for torchgeo.transforms.color

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo color transforms."""

from typing import Optional

from kornia.augmentation import IntensityAugmentationBase2D
from torch import Tensor


class RandomGrayscale(IntensityAugmentationBase2D):
    r"""Randomly convert images to grayscale with probability ``p``.

    There is no single agreed upon definition of grayscale for MSI. Some
    possibilities include:

    * Average of all bands: :math:`\frac{1}{C}` where :math:`C` is the number of
      spectral channels.
    * RGB-only bands: :math:`[0.299, 0.587, 0.114]` for the RGB channels, 0 for
      all other channels.
    * PCA: the first principal component across the spectral axis computed via
      PCA, minimizes redundant information.

    The weight vector you provide will be automatically rescaled to sum to 1 in
    order to avoid changing the intensity of the image.

    .. versionadded:: 0.5
    """

    def __init__(
        self,
        weights: Tensor,
        p: float = 0.1,
        same_on_batch: bool = False,
        keepdim: bool = False,
    ) -> None:
        """Initialize a new RandomGrayscale instance.

        Args:
            weights: Weights applied to each channel to compute a grayscale
                representation. Should be the same length as the number of channels.
                The tensor passed in is not modified; a rescaled copy is stored.
            p: Probability of the image to be transformed to grayscale.
            same_on_batch: Apply the same transformation across the batch.
            keepdim: Whether to keep the output shape the same as input (True)
                or broadcast it to the batch form (False).

        Raises:
            ValueError: If ``weights`` sums to zero, since it cannot then be
                rescaled to sum to 1.
        """
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)

        total = weights.sum()
        if total == 0:
            raise ValueError("weights must not sum to zero")

        # Rescale to sum to 1 out of place: an in-place ``/=`` would mutate the
        # caller's tensor and fails outright for integer weight tensors.
        weights = weights / total
        self.flags = {"weights": weights}

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Tensor],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the transform.

        Args:
            input: The input tensor. Channels are reduced over dim -3, so this
                presumably expects a (..., C, H, W) layout.
            params: Generated parameters.
            flags: Static parameters.
            transform: The geometric transformation tensor.

        Returns:
            The augmented input: the weighted channel sum, broadcast back to
            every channel so the output has the same shape as ``input``.
        """
        # Append two singleton dims so the per-channel weights broadcast over
        # the two trailing (spatial) axes of the input.
        weights = flags["weights"][..., :, None, None].to(input.device)
        out = input * weights
        out = out.sum(dim=-3)
        # ``expand`` repeats the single grayscale band as a view (no copy).
        out = out.unsqueeze(-3).expand(input.shape)
        return out

© Copyright 2021, Microsoft Corporation. Revision b9653beb.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
v0.5.2
v0.5.1
v0.5.0
v0.4.1
v0.4.0
v0.3.1
v0.3.0
v0.2.1
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources