Source code for torchgeo.datasets.cv4a_kenya_crop_type

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CV4A Kenya Crop Type dataset."""

import csv
import os
from functools import lru_cache
from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive

# TODO: read geospatial information from stac.json files
class CV4AKenyaCropType(NonGeoDataset):
    """CV4A Kenya Crop Type dataset.

    Used in a competition in the Computer Vision for Agriculture (CV4A) workshop in
    ICLR 2020. See `this website <https://mlhub.earth/data/ref_african_crops_kenya_02>`__
    for dataset details.

    Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time.

    Each tile has:

    * 13 multi-band observations throughout the growing season. Each observation
      includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer.
      The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A,
      B09, B11, B12] (refer to Sentinel-2 documentation for more information about
      the bands). The cloud probability layer is a product of the
      Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated
      cloud probability (0-100%) per pixel. All of the bands are mapped to a common
      10 m spatial resolution grid.
    * A raster layer indicating the crop ID for the fields in the training set.
    * A raster layer indicating field IDs for the fields (both training and test
      sets). Fields with a crop ID 0 are the test fields.

    There are 3,286 fields in the train set and 1,402 fields in the test set.

    If you use this dataset in your research, please cite the following dataset:

    * https://doi.org/10.34911/RDNT.DW605X

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub
    """

    collection_ids = [
        "ref_african_crops_kenya_02_labels",
        "ref_african_crops_kenya_02_source",
    ]
    image_meta = {
        "filename": "ref_african_crops_kenya_02_source.tar.gz",
        "md5": "9c2004782f6dc83abb1bf45ba4d0da46",
    }
    target_meta = {
        "filename": "ref_african_crops_kenya_02_labels.tar.gz",
        "md5": "93949abd0ae82ba564f5a933cefd8215",
    }

    tile_names = [
        "ref_african_crops_kenya_02_tile_00",
        "ref_african_crops_kenya_02_tile_01",
        "ref_african_crops_kenya_02_tile_02",
        "ref_african_crops_kenya_02_tile_03",
    ]
    dates = [
        "20190606",
        "20190701",
        "20190706",
        "20190711",
        "20190721",
        "20190805",
        "20190815",
        "20190825",
        "20190909",
        "20190919",
        "20190924",
        "20191004",
        "20191103",
    ]
    band_names = (
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B8A",
        "B09",
        "B11",
        "B12",
        "CLD",
    )

    rgb_bands = ["B04", "B03", "B02"]

    # Same for all tiles
    tile_height = 3035
    tile_width = 2016

    def __init__(
        self,
        root: str = "data",
        chip_size: int = 256,
        stride: int = 128,
        bands: tuple[str, ...] = band_names,
        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
        verbose: bool = False,
    ) -> None:
        """Initialize a new CV4A Kenya Crop Type Dataset instance.

        Args:
            root: root directory where dataset can be found
            chip_size: size of chips
            stride: spacing between chips, if less than chip_size, then there
                will be overlap between chips
            bands: the subset of bands to load
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)
            verbose: if True, print messages when new tiles are loaded

        Raises:
            AssertionError: if ``bands`` is not a tuple
            ValueError: if ``bands`` contains an unknown band name
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        self._validate_bands(bands)

        self.root = root
        self.chip_size = chip_size
        self.stride = stride
        self.bands = bands
        self.transforms = transforms
        self.checksum = checksum
        self.verbose = verbose

        if download:
            self._download(api_key)

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                + "You can use download=True to download it"
            )

        # Calculate the indices that we will use over all tiles. Every tile has the
        # same size, so the per-axis offsets are computed once and reused.
        y_offsets = self._chip_offsets(self.tile_height, self.chip_size, self.stride)
        x_offsets = self._chip_offsets(self.tile_width, self.chip_size, self.stride)
        self.chips_metadata: list[tuple[int, int, int]] = [
            (tile_index, y, x)
            for tile_index in range(len(self.tile_names))
            for y in y_offsets
            for x in x_offsets
        ]

    @staticmethod
    def _chip_offsets(length: int, chip_size: int, stride: int) -> list[int]:
        """Return chip start offsets along one axis of a tile.

        Offsets step by ``stride`` from 0, and a final chip flush with the far edge
        (offset ``length - chip_size``) is always appended so the whole axis is
        covered.

        Args:
            length: size of the tile along this axis
            chip_size: size of a chip along this axis
            stride: spacing between consecutive chip offsets

        Returns:
            list of chip start offsets
        """
        last = length - chip_size
        return list(range(0, last, stride)) + [last]

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data, labels, field ids, and metadata at that index
        """
        tile_index, y, x = self.chips_metadata[index]
        tile_name = self.tile_names[tile_index]

        img = self._load_all_image_tiles(tile_name, self.bands)
        labels, field_ids = self._load_label_tile(tile_name)

        # NOTE(review): these slices are views into cached tile tensors; a transform
        # that mutates the sample in place would also modify the cache.
        img = img[:, :, y : y + self.chip_size, x : x + self.chip_size]
        labels = labels[y : y + self.chip_size, x : x + self.chip_size]
        field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]

        sample = {
            "image": img,
            "mask": labels,
            "field_ids": field_ids,
            "tile_index": torch.tensor(tile_index),
            "x": torch.tensor(x),
            "y": torch.tensor(y),
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of chips in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.chips_metadata)

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps the instance
    # (and its cached tensors) alive for the lifetime of the class-level cache.
    @lru_cache(maxsize=128)
    def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]:
        """Load a single _tile_ of labels and field_ids.

        Args:
            tile_name: name of tile to load

        Returns:
            tuple of labels and field ids

        Raises:
            AssertionError: if ``tile_name`` is invalid
        """
        assert tile_name in self.tile_names

        if self.verbose:
            print(f"Loading labels/field_ids for {tile_name}")

        directory = os.path.join(
            self.root, "ref_african_crops_kenya_02_labels", tile_name + "_label"
        )

        with Image.open(os.path.join(directory, "labels.tif")) as img:
            array: "np.typing.NDArray[np.int_]" = np.array(img)
            labels = torch.from_numpy(array)

        with Image.open(os.path.join(directory, "field_ids.tif")) as img:
            array = np.array(img)
            field_ids = torch.from_numpy(array)

        return (labels, field_ids)

    def _validate_bands(self, bands: tuple[str, ...]) -> None:
        """Validate list of bands.

        Args:
            bands: user-provided tuple of bands to load

        Raises:
            AssertionError: if ``bands`` is not a tuple
            ValueError: if an invalid band name is provided
        """
        assert isinstance(bands, tuple), "The list of bands must be a tuple"
        for band in bands:
            if band not in self.band_names:
                raise ValueError(f"'{band}' is an invalid band name.")

    @lru_cache(maxsize=128)
    def _load_all_image_tiles(
        self, tile_name: str, bands: tuple[str, ...] = band_names
    ) -> Tensor:
        """Load all the imagery (across time) for a single _tile_.

        Optionally allows for subsetting of the bands that are loaded.

        Args:
            tile_name: name of tile to load
            bands: tuple of bands to load

        Returns:
            imagery of shape (number of dates, number of bands, tile height,
            tile width), i.e. (13, len(bands), 3035, 2016)

        Raises:
            AssertionError: if ``tile_name`` is invalid
        """
        assert tile_name in self.tile_names

        if self.verbose:
            print(f"Loading all imagery for {tile_name}")

        img = torch.zeros(
            len(self.dates),
            len(bands),
            self.tile_height,
            self.tile_width,
            dtype=torch.float32,
        )

        for date_index, date in enumerate(self.dates):
            # Forward the requested ``bands`` (not ``self.bands``) so the loaded
            # data always matches the buffer allocated above.
            img[date_index] = self._load_single_image_tile(tile_name, date, bands)

        return img

    # Not cached: its only consumer, ``_load_all_image_tiles``, is cached per
    # (tile, bands), so caching here too would just keep a second copy of every
    # full-resolution per-date tensor in memory.
    def _load_single_image_tile(
        self, tile_name: str, date: str, bands: tuple[str, ...]
    ) -> Tensor:
        """Load the imagery for a single tile for a single date.

        Optionally allows for subsetting of the bands that are loaded.

        Args:
            tile_name: name of tile to load
            date: date of tile to load
            bands: bands to load

        Returns:
            array of shape (len(bands), tile height, tile width) for one date

        Raises:
            AssertionError: if ``tile_name`` or ``date`` is invalid
        """
        assert tile_name in self.tile_names
        assert date in self.dates

        if self.verbose:
            print(f"Loading imagery for {tile_name} at {date}")

        img = torch.zeros(
            len(bands), self.tile_height, self.tile_width, dtype=torch.float32
        )
        for band_index, band_name in enumerate(bands):
            filepath = os.path.join(
                self.root,
                "ref_african_crops_kenya_02_source",
                f"{tile_name}_{date}",
                f"{band_name}.tif",
            )
            with Image.open(filepath) as band_img:
                array: "np.typing.NDArray[np.int_]" = np.array(band_img)
                img[band_index] = torch.from_numpy(array)

        return img

    def _check_integrity(self) -> bool:
        """Check integrity of dataset.

        Returns:
            True if dataset files are found and/or MD5s match, else False
        """
        images: bool = check_integrity(
            os.path.join(self.root, self.image_meta["filename"]),
            self.image_meta["md5"] if self.checksum else None,
        )

        targets: bool = check_integrity(
            os.path.join(self.root, self.target_meta["filename"]),
            self.target_meta["md5"] if self.checksum else None,
        )

        return images and targets

    def get_splits(self) -> tuple[list[int], list[int]]:
        """Get the field_ids for the train/test splits from the dataset directory.

        Returns:
            list of training field_ids and list of testing field_ids
        """
        train_field_ids = []
        test_field_ids = []
        splits_fn = os.path.join(
            self.root,
            "ref_african_crops_kenya_02_labels",
            "_common",
            "field_train_test_ids.csv",
        )

        with open(splits_fn, newline="") as f:
            reader = csv.reader(f)

            # Skip header row
            next(reader)

            for row in reader:
                train_field_ids.append(int(row[0]))
                # The test column is shorter than the train column, so trailing
                # rows leave it empty.
                if row[1]:
                    test_field_ids.append(int(row[1]))

        return train_field_ids, test_field_ids

    def _download(self, api_key: Optional[str] = None) -> None:
        """Download the dataset and extract it.

        Does nothing if the archives are already present and verified. Integrity of
        freshly downloaded archives is not re-checked here; ``__init__`` calls
        :meth:`_check_integrity` afterwards.

        Args:
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        for collection_id in self.collection_ids:
            download_radiant_mlhub_collection(collection_id, self.root, api_key)

        image_archive_path = os.path.join(self.root, self.image_meta["filename"])
        target_archive_path = os.path.join(self.root, self.target_meta["filename"])
        for fn in [image_archive_path, target_archive_path]:
            extract_archive(fn, self.root)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        time_step: int = 0,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            time_step: time step at which to access image, beginning with 0
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample

        Raises:
            ValueError: if the dataset was not loaded with all of the RGB bands
            AssertionError: if ``time_step`` is out of range for the sample

        .. versionadded:: 0.2
        """
        rgb_indices = []
        for band in self.rgb_bands:
            if band in self.bands:
                rgb_indices.append(self.bands.index(band))
            else:
                raise ValueError("Dataset doesn't contain some of the RGB bands")

        has_prediction = "prediction" in sample
        n_cols = 3 if has_prediction else 2

        image, mask = sample["image"], sample["mask"]

        assert time_step <= image.shape[0] - 1, (
            "The specified time step"
            " does not exist, image only contains {} time"
            " instances."
        ).format(image.shape[0])

        image = image[time_step, rgb_indices, :, :]

        # A single row of panels: the figure width grows with the number of columns.
        fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(n_cols * 5, 5))

        axs[0].imshow(image.permute(1, 2, 0))
        axs[0].axis("off")
        axs[1].imshow(mask)
        axs[1].axis("off")

        if has_prediction:
            axs[2].imshow(sample["prediction"])
            axs[2].axis("off")

        if show_titles:
            axs[0].set_title("Image")
            axs[1].set_title("Mask")
            if has_prediction:
                axs[2].set_title("Prediction")

        if suptitle is not None:
            fig.suptitle(suptitle)

        return fig

© Copyright 2021, Microsoft Corporation. Revision 6694cbd4.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources