Shortcuts

Source code for torchgeo.samplers.single

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo samplers."""

import abc
import random
from typing import Iterator, Optional, Tuple, Union

from torch.utils.data import Sampler

from torchgeo.datasets.geo import GeoDataset
from torchgeo.datasets.utils import BoundingBox

from .utils import _to_tuple, get_random_bounding_box

# Report ``Sampler`` under its public import path rather than the private
# submodule it is defined in.
# NOTE(review): presumably this is so documentation tooling resolves
# ``torch.utils.data.Sampler`` correctly; see the upstream discussion:
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
setattr(Sampler, "__module__", "torch.utils.data")


class GeoSampler(Sampler[BoundingBox], abc.ABC):
    """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.

    Where PyTorch's :class:`~torch.utils.data.Sampler` iterates over plain
    dataset indices, a :class:`GeoSampler` iterates over ``BoundingBox``
    queries. Each query describes a spatial extent (minx, maxx, miny, maxy)
    and a temporal extent (mint, maxt). Together these are enough to index
    any :class:`~torchgeo.datasets.GeoDataset`.
    """

    @abc.abstractmethod
    def __iter__(self) -> Iterator[BoundingBox]:
        """Iterate over the queries used to index a dataset.

        Subclasses must override this method.

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
class RandomGeoSampler(GeoSampler):
    """Samples elements from a region of interest randomly.

    This is particularly useful during training when you want to maximize the size of
    the dataset and return as many random :term:`chips <chip>` as possible.

    This sampler is not recommended for use with tile-based datasets. Use
    :class:`RandomBatchGeoSampler` instead.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        size: Union[Tuple[float, float], float],
        length: int,
        roi: Optional[BoundingBox] = None,
    ) -> None:
        """Initialize a new Sampler instance.

        The ``size`` argument can either be:

        * a single ``float`` - in which case the same value is used for the height and
          width dimension
        * a ``tuple`` of two floats - in which case, the first *float* is used for the
          height dimension, and the second *float* for the width dimension

        Args:
            dataset: dataset to index from
            size: dimensions of each :term:`patch` in units of CRS
            length: number of random samples to draw per epoch
            roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
                (defaults to the bounds of ``dataset.index``)
        """
        self.index = dataset.index
        self.res = dataset.res
        self.size = _to_tuple(size)
        self.length = length
        # Fall back to the full extent of the spatial index when no ROI is given.
        self.roi = BoundingBox(*self.index.bounds) if roi is None else roi
        # Every index entry that intersects the ROI is a candidate tile.
        self.hits = list(self.index.intersection(self.roi, objects=True))

    def __iter__(self) -> Iterator[BoundingBox]:
        """Yield ``len(self)`` randomly placed query boxes.

        Each sample first picks a candidate tile uniformly at random, then
        delegates to ``get_random_bounding_box`` to place a patch within it.

        NOTE(review): the tile's full bounds are used, not their intersection
        with ``self.roi``, so tiles that only partly overlap the ROI may yield
        patches outside it.

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
        for _ in range(len(self)):
            tile = random.choice(self.hits)
            tile_bounds = BoundingBox(*tile.bounds)
            yield get_random_bounding_box(tile_bounds, self.size, self.res)

    def __len__(self) -> int:
        """Return the number of samples in a single epoch.

        Returns:
            length of the epoch
        """
        return self.length
class GridGeoSampler(GeoSampler):
    """Samples elements in a grid-like fashion.

    This is particularly useful during evaluation when you want to make predictions for
    an entire region of interest. You want to minimize the amount of redundant
    computation by minimizing overlap between :term:`chips <chip>`.

    Usually the stride should be slightly smaller than the chip size such that each chip
    has some small overlap with surrounding chips. This is used to prevent `stitching
    artifacts <https://arxiv.org/abs/1805.12219>`_ when combining each prediction patch.
    The overlap between each chip (``chip_size - stride``) should be approximately equal
    to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
    the CNN.

    When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
    a non-tile-based dataset if possible.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        size: Union[Tuple[float, float], float],
        stride: Union[Tuple[float, float], float],
        roi: Optional[BoundingBox] = None,
    ) -> None:
        """Initialize a new Sampler instance.

        The ``size`` and ``stride`` arguments can either be:

        * a single ``float`` - in which case the same value is used for the height and
          width dimension
        * a ``tuple`` of two floats - in which case, the first *float* is used for the
          height dimension, and the second *float* for the width dimension

        Args:
            dataset: dataset to index from
            size: dimensions of each :term:`patch` in units of CRS
            stride: distance to skip between each patch
            roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
                (defaults to the bounds of ``dataset.index``)
        """
        self.index = dataset.index
        self.size = _to_tuple(size)
        self.stride = _to_tuple(stride)
        if roi is None:
            roi = BoundingBox(*self.index.bounds)
        self.roi = roi
        self.hits = list(self.index.intersection(roi, objects=True))

        # Count patches with the same helper __iter__ uses, so len(self) always
        # equals the number of boxes actually yielded.
        self.length: int = 0
        for hit in self.hits:
            rows, cols = self._grid_shape(BoundingBox(*hit.bounds))
            self.length += rows * cols

    def _grid_shape(self, bounds: BoundingBox) -> Tuple[int, int]:
        """Return how many patch rows and columns fit inside ``bounds``.

        The counts are clamped at zero. Without the clamp, a tile smaller than
        the patch could give a negative row and column count whose product is
        positive, or a single negative count that makes the total negative.
        ``__len__`` would then disagree with ``__iter__``.

        Args:
            bounds: extent of a single tile

        Returns:
            (rows, cols) number of patches along y and x
        """
        # Floor division drops any leftover strip narrower than one stride at
        # the max edge, so that strip is not covered by a patch.
        rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
        cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
        return max(rows, 0), max(cols, 0)

    def __iter__(self) -> Iterator[BoundingBox]:
        """Return the index of a dataset.

        Tiles are visited in ``self.hits`` order. Within a tile, patches are
        emitted row by row (increasing y), then column by column (increasing x).

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
        # For each tile...
        for hit in self.hits:
            bounds = BoundingBox(*hit.bounds)
            rows, cols = self._grid_shape(bounds)
            mint = bounds.mint
            maxt = bounds.maxt

            # For each row...
            for i in range(rows):
                miny = bounds.miny + i * self.stride[0]
                maxy = miny + self.size[0]

                # For each column...
                for j in range(cols):
                    minx = bounds.minx + j * self.stride[1]
                    maxx = minx + self.size[1]

                    yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

    def __len__(self) -> int:
        """Return the number of samples over the ROI.

        Returns:
            number of patches that will be sampled
        """
        return self.length

© Copyright 2021, Microsoft Corporation. Revision c2b56148.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: v0.1.1
Versions
latest
stable
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources