Source code for torchgeo.datamodules.utils

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Common datamodule utilities."""

import math
from collections.abc import Iterable
from typing import Any

import numpy as np
import torch
from torch import Tensor


# Based on lightning_lite.utilities.exceptions
class MisconfigurationException(Exception):
    """Exception used to inform users of misuse with Lightning."""


def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
    """Custom collate fn for object detection and instance segmentation.

    Args:
        batch: list of sample dicts returned by dataset

    Returns:
        batch dict output

    .. versionadded:: 0.6
    """
    output: dict[str, Any] = {}
    output['image'] = torch.stack([sample['image'] for sample in batch])
    output['bbox_xyxy'] = [sample['bbox_xyxy'].float() for sample in batch]
    if 'label' in batch[0].keys():
        output['label'] = [sample['label'] for sample in batch]
    else:
        output['label'] = [
            torch.tensor([1] * len(sample['bbox_xyxy'])) for sample in batch
        ]

    if 'mask' in batch[0]:
        output['mask'] = [sample['mask'] for sample in batch]
    return output


def group_shuffle_split(
    groups: Iterable[Any],
    train_size: float | None = None,
    test_size: float | None = None,
    random_state: int | None = None,
) -> tuple[list[int], list[int]]:
    """Method for performing a single group-wise shuffle split of data.

    Loosely based on :class:`sklearn.model_selection.GroupShuffleSplit`.

    Args:
        groups: a sequence of group values used to split. Should be in the same
            order as the data you want to split.
        train_size: the proportion of groups to include in the train split. If
            None, then it is set to complement `test_size`.
        test_size: the proportion of groups to include in the test split
            (rounded up). If None, then it is set to complement `train_size`.
        random_state: controls the random splits (passed as a seed to a
            numpy.random.Generator), set for reproducible splits.

    Returns:
        train_indices, test_indices

    Raises:
        ValueError if `train_size` and `test_size` do not sum to 1, aren't in
            the range (0,1), or are both None.
        ValueError if the number of training or testing groups turns out to be 0.
    """
    if train_size is None and test_size is None:
        raise ValueError('You must specify `train_size`, `test_size`, or both.')
    if (train_size is not None and test_size is not None) and (
        not math.isclose(train_size + test_size, 1)
    ):
        raise ValueError('`train_size` and `test_size` must sum to 1.')

    if train_size is None and test_size is not None:
        train_size = 1 - test_size
    if test_size is None and train_size is not None:
        test_size = 1 - train_size

    assert train_size is not None and test_size is not None

    if train_size <= 0 or train_size >= 1 or test_size <= 0 or test_size >= 1:
        raise ValueError('`train_size` and `test_size` must be in the range (0,1).')

    group_vals = sorted(set(groups))
    n_groups = len(group_vals)
    n_test_groups = round(n_groups * test_size)
    n_train_groups = n_groups - n_test_groups

    if n_train_groups == 0 or n_test_groups == 0:
        raise ValueError(
            f'{n_groups} groups were found, however the current settings of '
            + '`train_size` and `test_size` result in 0 training or testing groups.'
        )

    generator = np.random.default_rng(seed=random_state)
    train_group_vals = set(
        generator.choice(group_vals, size=n_train_groups, replace=False)
    )

    train_idxs = []
    test_idxs = []
    for i, group_val in enumerate(groups):
        if group_val in train_group_vals:
            train_idxs.append(i)
        else:
            test_idxs.append(i)

    return train_idxs, test_idxs
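
A minimal usage sketch for the two helpers above, assuming torchgeo is installed and the module is importable as torchgeo.datamodules.utils. The tensor shapes, box coordinates, and group names are illustrative only, not part of this module.

# Sketch: collating a hand-built detection batch and splitting by group.
import torch
from torchgeo.datamodules.utils import collate_fn_detection, group_shuffle_split

# Two hypothetical detection samples: 3-band 64x64 images with a few boxes each.
batch = [
    {
        'image': torch.rand(3, 64, 64),
        'bbox_xyxy': torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    },
    {
        'image': torch.rand(3, 64, 64),
        'bbox_xyxy': torch.tensor([[5.0, 5.0, 20.0, 20.0], [1.0, 2.0, 3.0, 4.0]]),
    },
]
collated = collate_fn_detection(batch)
print(collated['image'].shape)     # torch.Size([2, 3, 64, 64]) -- images are stacked
print(len(collated['bbox_xyxy']))  # 2 -- boxes remain a per-sample list
print(collated['label'])           # all-ones labels are filled in when 'label' is absent

# Group-wise split: all samples sharing a group value land on the same side.
groups = ['tile_a', 'tile_a', 'tile_b', 'tile_c', 'tile_c', 'tile_d']
train_idxs, test_idxs = group_shuffle_split(groups, test_size=0.25, random_state=0)
print(train_idxs, test_idxs)       # index lists into the original sequence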
