Source code for torchgeo.samplers.single
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""TorchGeo samplers."""
import abc
import random
from typing import Iterator, Optional, Tuple, Union
from torch.utils.data import Sampler
from torchgeo.datasets.geo import GeoDataset
from torchgeo.datasets.utils import BoundingBox
from .utils import _to_tuple, get_random_bounding_box
# Re-home ``Sampler`` under its public package path. Presumably this is so that
# generated documentation / cross-references name ``torch.utils.data.Sampler``
# rather than an internal submodule — see the upstream discussion:
#   https://github.com/pytorch/pytorch/issues/60979
#   https://github.com/pytorch/pytorch/pull/61045
setattr(Sampler, "__module__", "torch.utils.data")
class GeoSampler(Sampler[BoundingBox], abc.ABC):
    """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.

    A plain PyTorch :class:`~torch.utils.data.Sampler` hands out indices into a
    dataset. A :class:`GeoSampler` instead hands out bounding-box queries that
    describe a spatial extent (x and y ranges) and a temporal extent (t range).
    That is the information a :class:`~torchgeo.datasets.GeoDataset` needs in
    order to be indexed.

    Subclasses must implement :meth:`__iter__`.
    """

    @abc.abstractmethod
    def __iter__(self) -> Iterator[BoundingBox]:
        """Yield the queries used to index a dataset.

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
class RandomGeoSampler(GeoSampler):
    """Samples elements from a region of interest randomly.

    This is particularly useful during training when you want to maximize the size of
    the dataset and return as many random :term:`chips <chip>` as possible.

    This sampler is not recommended for use with tile-based datasets. Use
    :class:`RandomBatchGeoSampler` instead.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        size: Union[Tuple[float, float], float],
        length: int,
        roi: Optional[BoundingBox] = None,
    ) -> None:
        """Initialize a new Sampler instance.

        The ``size`` argument can either be:

        * a single ``float`` - in which case the same value is used for the height and
          width dimension
        * a ``tuple`` of two floats - in which case, the first *float* is used for the
          height dimension, and the second *float* for the width dimension

        Args:
            dataset: dataset to index from
            size: dimensions of each :term:`patch` in units of CRS
            length: number of random samples to draw per epoch
            roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
                (defaults to the bounds of ``dataset.index``)
        """
        self.index = dataset.index
        self.res = dataset.res
        self.size = _to_tuple(size)
        self.length = length
        if roi is None:
            roi = BoundingBox(*self.index.bounds)
        self.roi = roi
        # Index entries (tiles) whose extent overlaps the region of interest.
        self.hits = list(self.index.intersection(roi, objects=True))

    def _clip_to_roi(self, bounds: BoundingBox) -> BoundingBox:
        """Return the part of ``bounds`` that lies inside :attr:`roi`.

        Only called on tiles returned by the ROI intersection query, so the
        overlap is expected to have min <= max in every dimension.

        Args:
            bounds: extent of a single tile from the index

        Returns:
            intersection of ``bounds`` and the region of interest
        """
        roi = self.roi
        return BoundingBox(
            max(bounds.minx, roi.minx),
            min(bounds.maxx, roi.maxx),
            max(bounds.miny, roi.miny),
            min(bounds.maxy, roi.maxy),
            max(bounds.mint, roi.mint),
            min(bounds.maxt, roi.maxt),
        )

    def __iter__(self) -> Iterator[BoundingBox]:
        """Return the index of a dataset.

        Draws ``len(self)`` chips. Each chip is drawn by first choosing one of the
        intersecting tiles uniformly at random, then choosing a random window inside
        that tile's overlap with the region of interest.

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
        for _ in range(len(self)):
            # Choose a random tile
            hit = random.choice(self.hits)
            # Only sample from the part of the tile inside the ROI; sampling from the
            # full tile would return chips outside the requested region.
            bounds = self._clip_to_roi(BoundingBox(*hit.bounds))
            # Choose a random index within that region
            yield get_random_bounding_box(bounds, self.size, self.res)

    def __len__(self) -> int:
        """Return the number of samples in a single epoch.

        Returns:
            length of the epoch
        """
        return self.length
class GridGeoSampler(GeoSampler):
    """Samples elements in a grid-like fashion.

    This is particularly useful during evaluation when you want to make predictions for
    an entire region of interest. You want to minimize the amount of redundant
    computation by minimizing overlap between :term:`chips <chip>`.

    Usually the stride should be slightly smaller than the chip size such that each chip
    has some small overlap with surrounding chips. This is used to prevent `stitching
    artifacts <https://arxiv.org/abs/1805.12219>`_ when combining each prediction patch.
    The overlap between each chip (``chip_size - stride``) should be approximately equal
    to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
    the CNN.

    When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
    a non-tile-based dataset if possible.
    """

    def __init__(
        self,
        dataset: GeoDataset,
        size: Union[Tuple[float, float], float],
        stride: Union[Tuple[float, float], float],
        roi: Optional[BoundingBox] = None,
    ) -> None:
        """Initialize a new Sampler instance.

        The ``size`` and ``stride`` arguments can either be:

        * a single ``float`` - in which case the same value is used for the height and
          width dimension
        * a ``tuple`` of two floats - in which case, the first *float* is used for the
          height dimension, and the second *float* for the width dimension

        Args:
            dataset: dataset to index from
            size: dimensions of each :term:`patch` in units of CRS
            stride: distance to skip between each patch
            roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
                (defaults to the bounds of ``dataset.index``)
        """
        self.index = dataset.index
        self.size = _to_tuple(size)
        self.stride = _to_tuple(stride)
        if roi is None:
            roi = BoundingBox(*self.index.bounds)
        self.roi = roi
        # Index entries (tiles) whose extent overlaps the region of interest.
        self.hits = list(self.index.intersection(roi, objects=True))

        # Precompute the total chip count using exactly the same per-tile geometry
        # that __iter__ uses, so that len() always matches the number yielded.
        self.length: int = 0
        for hit in self.hits:
            rows, cols = self._grid_shape(self._clip_to_roi(BoundingBox(*hit.bounds)))
            self.length += rows * cols

    def _clip_to_roi(self, bounds: BoundingBox) -> BoundingBox:
        """Return the part of ``bounds`` that lies inside :attr:`roi`.

        Only called on tiles returned by the ROI intersection query, so the
        overlap is expected to have min <= max in every dimension.

        Args:
            bounds: extent of a single tile from the index

        Returns:
            intersection of ``bounds`` and the region of interest
        """
        roi = self.roi
        return BoundingBox(
            max(bounds.minx, roi.minx),
            min(bounds.maxx, roi.maxx),
            max(bounds.miny, roi.miny),
            min(bounds.maxy, roi.maxy),
            max(bounds.mint, roi.mint),
            min(bounds.maxt, roi.maxt),
        )

    def _grid_shape(self, bounds: BoundingBox) -> Tuple[int, int]:
        """Return how many chip rows and columns fit inside ``bounds``.

        Args:
            bounds: region to tile with chips of ``self.size`` spaced by ``self.stride``

        Returns:
            (rows, cols). Each is clamped at 0, so a region smaller than one chip
            yields no chips instead of a negative count that would corrupt the length.
        """
        rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
        cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
        return max(rows, 0), max(cols, 0)

    def __iter__(self) -> Iterator[BoundingBox]:
        """Return the index of a dataset.

        Tiles are visited in index order. Within each tile's overlap with the region of
        interest, chips are emitted row by row (increasing y), then column by
        column (increasing x).

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
        # For each tile...
        for hit in self.hits:
            # Restrict the grid to the part of the tile inside the ROI.
            bounds = self._clip_to_roi(BoundingBox(*hit.bounds))
            rows, cols = self._grid_shape(bounds)
            mint = bounds.mint
            maxt = bounds.maxt

            # For each row...
            for i in range(rows):
                miny = bounds.miny + i * self.stride[0]
                maxy = miny + self.size[0]

                # For each column...
                for j in range(cols):
                    minx = bounds.minx + j * self.stride[1]
                    maxx = minx + self.size[1]

                    yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

    def __len__(self) -> int:
        """Return the number of samples over the ROI.

        Returns:
            number of patches that will be sampled
        """
        return self.length