Source code for torchgeo.datasets.splits

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Dataset splitting utilities."""

from collections.abc import Sequence
from copy import deepcopy
from itertools import accumulate
from math import floor, isclose
from typing import Optional, Union, cast

from rtree.index import Index, Property
from torch import Generator, default_generator, randint, randperm

from ..datasets import GeoDataset
from .utils import BoundingBox

__all__ = (
    "random_bbox_assignment",
    "random_bbox_splitting",
    "random_grid_cell_assignment",
    "roi_split",
    "time_series_split",
)


def _fractions_to_lengths(fractions: Sequence[float], total: int) -> Sequence[int]:
    """Utility to divide a number into a list of integers according to fractions.

    Implementation based on :func:`torch.utils.data.random_split`.

    Args:
        fractions: list of fractions
        total: total to be divided

    Returns:
        List of lengths.

    .. versionadded:: 0.5
    """
    lengths = [floor(frac * total) for frac in fractions]
    remainder = int(total - sum(lengths))
    # Add 1 to all the lengths in round-robin fashion until the remainder is 0
    for i in range(remainder):
        idx_to_add_at = i % len(lengths)
        lengths[idx_to_add_at] += 1
    return lengths
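
# Worked example (illustrative): _fractions_to_lengths([0.5, 0.3, 0.2], 10) gives
# [5, 3, 2] directly, while _fractions_to_lengths([0.5, 0.5], 7) floors to [3, 3]
# and then hands out the remaining 1 round-robin, yielding [4, 3].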


def random_bbox_assignment(
    dataset: GeoDataset,
    lengths: Sequence[float],
    generator: Optional[Generator] = default_generator,
) -> list[GeoDataset]:
    """Split a GeoDataset by randomly assigning its index's BoundingBoxes.

    This function will go through each BoundingBox in the GeoDataset's index and
    randomly assign it to new GeoDatasets.

    Args:
        dataset: dataset to be split
        lengths: lengths or fractions of splits to be produced
        generator: (optional) generator used for the random permutation

    Returns:
        A list of the subset datasets.

    .. versionadded:: 0.5
    """
    if not (isclose(sum(lengths), 1) or isclose(sum(lengths), len(dataset))):
        raise ValueError(
            "Sum of input lengths must equal 1 or the length of dataset's index."
        )

    if any(n <= 0 for n in lengths):
        raise ValueError("All items in input lengths must be greater than 0.")

    if isclose(sum(lengths), 1):
        lengths = _fractions_to_lengths(lengths, len(dataset))
    lengths = cast(Sequence[int], lengths)

    hits = list(dataset.index.intersection(dataset.index.bounds, objects=True))

    hits = [hits[i] for i in randperm(sum(lengths), generator=generator)]

    new_indexes = [
        Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths
    ]

    for i, length in enumerate(lengths):
        for j in range(length):
            hit = hits.pop()
            new_indexes[i].insert(j, hit.bounds, hit.object)

    new_datasets = []
    for index in new_indexes:
        ds = deepcopy(dataset)
        ds.index = index
        new_datasets.append(ds)

    return new_datasets
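

# Usage sketch (hypothetical): ``my_dataset`` stands in for any existing GeoDataset.
# A seeded torch Generator makes the random assignment reproducible.
#
#     generator = Generator().manual_seed(0)
#     train_ds, val_ds, test_ds = random_bbox_assignment(
#         my_dataset, lengths=[0.7, 0.15, 0.15], generator=generator
#     )
#
# Each returned dataset is a deep copy of ``my_dataset`` whose index holds a disjoint
# subset of the original bounding boxes.
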
def random_bbox_splitting(
    dataset: GeoDataset,
    fractions: Sequence[float],
    generator: Optional[Generator] = default_generator,
) -> list[GeoDataset]:
    """Split a GeoDataset by randomly splitting its index's BoundingBoxes.

    This function will go through each BoundingBox in the GeoDataset's index,
    split it in a random direction and assign the resulting BoundingBoxes to
    new GeoDatasets.

    Args:
        dataset: dataset to be split
        fractions: fractions of splits to be produced
        generator: generator used for the random permutation

    Returns:
        A list of the subset datasets.

    .. versionadded:: 0.5
    """
    if not isclose(sum(fractions), 1):
        raise ValueError("Sum of input fractions must equal 1.")

    if any(n <= 0 for n in fractions):
        raise ValueError("All items in input fractions must be greater than 0.")

    new_indexes = [
        Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions
    ]

    for i, hit in enumerate(
        dataset.index.intersection(dataset.index.bounds, objects=True)
    ):
        box = BoundingBox(*hit.bounds)
        fraction_left = 1.0

        # Randomly choose the split direction
        horizontal, flip = randint(0, 2, (2,), generator=generator)
        for j, fraction in enumerate(fractions):
            if fraction_left == fraction:
                # For the last fraction, no need to split again
                new_box = box
            elif flip:
                # new_box corresponds to fraction, box is the remainder that we might
                # split again in the next iteration. Each split is done according to
                # fraction wrt what's left
                box, new_box = box.split(
                    (fraction_left - fraction) / fraction_left, horizontal
                )
            else:
                # Same as above, but without flipping
                new_box, box = box.split(fraction / fraction_left, horizontal)

            new_indexes[j].insert(i, tuple(new_box), hit.object)
            fraction_left -= fraction
            horizontal = not horizontal

    new_datasets = []
    for index in new_indexes:
        ds = deepcopy(dataset)
        ds.index = index
        new_datasets.append(ds)

    return new_datasets
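

# Usage sketch (hypothetical): ``my_dataset`` stands in for any existing GeoDataset.
#
#     train_ds, val_ds = random_bbox_splitting(
#         my_dataset, fractions=[0.8, 0.2], generator=Generator().manual_seed(1)
#     )
#
# Unlike random_bbox_assignment, every original BoundingBox is cut into pieces here,
# so each returned dataset keeps a (smaller) region of every box in the index.
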
def random_grid_cell_assignment(
    dataset: GeoDataset,
    fractions: Sequence[float],
    grid_size: int = 6,
    generator: Optional[Generator] = default_generator,
) -> list[GeoDataset]:
    """Overlays a grid over a GeoDataset and randomly assigns cells to new GeoDatasets.

    This function will go through each BoundingBox in the GeoDataset's index, overlay
    a grid over it, and randomly assign each cell to new GeoDatasets.

    Args:
        dataset: dataset to be split
        fractions: fractions of splits to be produced
        grid_size: number of rows and columns for the grid
        generator: generator used for the random permutation

    Returns:
        A list of the subset datasets.

    .. versionadded:: 0.5
    """
    if not isclose(sum(fractions), 1):
        raise ValueError("Sum of input fractions must equal 1.")

    if any(n <= 0 for n in fractions):
        raise ValueError("All items in input fractions must be greater than 0.")

    if grid_size < 2:
        raise ValueError("Input grid_size must be greater than 1.")

    new_indexes = [
        Index(interleaved=False, properties=Property(dimension=3)) for _ in fractions
    ]

    lengths = _fractions_to_lengths(fractions, len(dataset) * grid_size**2)

    cells = []

    # Generate the grid's cells for each bbox in index
    for i, hit in enumerate(
        dataset.index.intersection(dataset.index.bounds, objects=True)
    ):
        minx, maxx, miny, maxy, mint, maxt = hit.bounds

        stridex = (maxx - minx) / grid_size
        stridey = (maxy - miny) / grid_size

        cells.extend(
            [
                (
                    (
                        minx + x * stridex,
                        minx + (x + 1) * stridex,
                        miny + y * stridey,
                        miny + (y + 1) * stridey,
                        mint,
                        maxt,
                    ),
                    hit.object,
                )
                for x in range(grid_size)
                for y in range(grid_size)
            ]
        )

    # Randomly assign cells to each new index
    cells = [cells[i] for i in randperm(len(cells), generator=generator)]

    for i, length in enumerate(lengths):
        for j in range(length):
            cell = cells.pop()
            new_indexes[i].insert(j, cell[0], cell[1])

    new_datasets = []
    for index in new_indexes:
        ds = deepcopy(dataset)
        ds.index = index
        new_datasets.append(ds)

    return new_datasets
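

# Usage sketch (hypothetical): ``my_dataset`` stands in for any existing GeoDataset.
# With grid_size=8 each BoundingBox is divided into 8 x 8 = 64 cells, and the shuffled
# cells (rather than whole boxes) are dealt out according to ``fractions``.
#
#     train_ds, val_ds, test_ds = random_grid_cell_assignment(
#         my_dataset,
#         fractions=[0.8, 0.1, 0.1],
#         grid_size=8,
#         generator=Generator().manual_seed(2),
#     )
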
def roi_split(dataset: GeoDataset, rois: Sequence[BoundingBox]) -> list[GeoDataset]:
    """Split a GeoDataset by intersecting it with an ROI for each desired new GeoDataset.

    Args:
        dataset: dataset to be split
        rois: regions of interest of splits to be produced

    Returns:
        A list of the subset datasets.

    .. versionadded:: 0.5
    """
    new_indexes = [
        Index(interleaved=False, properties=Property(dimension=3)) for _ in rois
    ]

    for i, roi in enumerate(rois):
        if any(roi.intersects(x) and (roi & x).area > 0 for x in rois[i + 1 :]):
            raise ValueError("ROIs in input rois can't overlap.")

        j = 0
        for hit in dataset.index.intersection(tuple(roi), objects=True):
            box = BoundingBox(*hit.bounds)
            new_box = box & roi
            if new_box.area > 0:
                new_indexes[i].insert(j, tuple(new_box), hit.object)
                j += 1

    new_datasets = []
    for index in new_indexes:
        ds = deepcopy(dataset)
        ds.index = index
        new_datasets.append(ds)

    return new_datasets
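

# Usage sketch (hypothetical): ``my_dataset`` and the coordinates below are
# placeholders; real ROIs must be expressed in the dataset's CRS and time range,
# and must not overlap in area.
#
#     bounds = my_dataset.bounds
#     train_roi = BoundingBox(0, 50, 0, 100, bounds.mint, bounds.maxt)
#     test_roi = BoundingBox(50, 100, 0, 100, bounds.mint, bounds.maxt)
#     train_ds, test_ds = roi_split(my_dataset, rois=[train_roi, test_roi])
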
def time_series_split(
    dataset: GeoDataset, lengths: Sequence[Union[float, tuple[float, float]]]
) -> list[GeoDataset]:
    """Split a GeoDataset on its time dimension to create non-overlapping GeoDatasets.

    Args:
        dataset: dataset to be split
        lengths: lengths, fractions or pairs of timestamps (start, end) of splits
            to be produced

    Returns:
        A list of the subset datasets.

    .. versionadded:: 0.5
    """
    minx, maxx, miny, maxy, mint, maxt = dataset.bounds

    totalt = maxt - mint

    if not all(isinstance(x, tuple) for x in lengths):
        lengths = cast(Sequence[float], lengths)

        if not (isclose(sum(lengths), 1) or isclose(sum(lengths), totalt)):
            raise ValueError(
                "Sum of input lengths must equal 1 or the dataset's time length."
            )

        if any(n <= 0 for n in lengths):
            raise ValueError("All items in input lengths must be greater than 0.")

        if isclose(sum(lengths), 1):
            lengths = [totalt * f for f in lengths]

        lengths = [
            (mint + offset - length, mint + offset)  # type: ignore[operator]
            for offset, length in zip(accumulate(lengths), lengths)
        ]

    lengths = cast(Sequence[tuple[float, float]], lengths)

    new_indexes = [
        Index(interleaved=False, properties=Property(dimension=3)) for _ in lengths
    ]

    _totalt = 0.0
    for i, (start, end) in enumerate(lengths):
        if start >= end:
            raise ValueError(
                "Pairs of timestamps in lengths must have end greater than start."
            )

        if start < mint or end > maxt:
            raise ValueError(
                "Pairs of timestamps in lengths can't be out of dataset's time bounds."
            )

        if any(start < x < end or start < y < end for x, y in lengths[i + 1 :]):
            raise ValueError("Pairs of timestamps in lengths can't overlap.")

        # Remove one microsecond from each BoundingBox's maxt to avoid overlapping
        offset = 0 if i == len(lengths) - 1 else 1e-6
        roi = BoundingBox(minx, maxx, miny, maxy, start, end - offset)
        j = 0
        for hit in dataset.index.intersection(tuple(roi), objects=True):
            box = BoundingBox(*hit.bounds)
            new_box = box & roi
            if new_box.volume > 0:
                new_indexes[i].insert(j, tuple(new_box), hit.object)
                j += 1

        _totalt += end - start

    if not isclose(_totalt, totalt):
        raise ValueError(
            "Pairs of timestamps in lengths must cover dataset's time bounds."
        )

    new_datasets = []
    for index in new_indexes:
        ds = deepcopy(dataset)
        ds.index = index
        new_datasets.append(ds)

    return new_datasets
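

# Usage sketch (hypothetical): ``my_dataset`` stands in for any existing GeoDataset
# with a meaningful time dimension; t0 < t1 < t2 are placeholder timestamps spanning
# the dataset's full time bounds.
#
#     # Fractions of the total time span, taken in chronological order
#     train_ds, val_ds, test_ds = time_series_split(my_dataset, lengths=[0.7, 0.15, 0.15])
#
#     # Or explicit (start, end) timestamp pairs covering the whole time extent
#     train_ds, test_ds = time_series_split(my_dataset, lengths=[(t0, t1), (t1, t2)])
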
