
Source code for torchgeo.datasets.south_africa_crop_type

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""South Africa Crop Type Competition Dataset."""

import os
import re
from collections.abc import Callable, Iterable
from typing import Any, cast

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

from .errors import RGBBandsMissingError
from .geo import RasterDataset
from .utils import BoundingBox

class SouthAfricaCropType(RasterDataset):
    """South Africa Crop Type Challenge dataset.

    The South Africa Crop Type Challenge dataset includes satellite imagery from
    Sentinel-1 and Sentinel-2 and labels for crop type that were collected by
    aerial and vehicle survey from May 2017 to March 2018. Data was provided by
    the Western Cape Department of Agriculture and is available via the Radiant
    Earth Foundation.

    For each field id the dataset contains time series imagery and a single label
    mask. Since TorchGeo does not yet support timeseries datasets, the first
    available imagery in July will be returned for each field. Note that the dates
    for S1 and S2 imagery for a given field are not guaranteed to be the same.
    Due to this date mismatch only S1 or S2 bands may be queried at a time, a mix
    of both is not supported. Each pixel in the label contains an integer field
    number and crop type class.

    Dataset format:

    * images are 2-band Sentinel 1 and 12-band Sentinel-2 data with a cloud mask
    * masks are tiff images with unique values representing the class and field id.

    Dataset classes:

    0. No Data
    1. Lucerne/Medics
    2. Planted pastures (perennial)
    3. Fallow
    4. Wine grapes
    5. Weeds
    6. Small grain grazing
    7. Wheat
    8. Canola
    9. Rooibos

    If you use this dataset in your research, please cite the following dataset:

    * Western Cape Department of Agriculture, Radiant Earth Foundation (2021)
      "Crop Type Classification Dataset for Western Cape, South Africa",
      Version 1.0, Radiant MLHub.

    .. versionadded:: 0.6
    """

    # Both patterns are compiled with re.VERBOSE, so the layout whitespace
    # inside them is not part of the match.
    s1_regex = r"""
        ^(?P<field_id>\d+)
        _(?P<date>\d{4}_07_\d{2})
        _(?P<band>[VH]{2})
        _10m"""
    s2_regex = r"""
        ^(?P<field_id>\d+)
        _(?P<date>\d{4}_07_\d{2})
        _(?P<band>(B[0-9A-Z]{2}))
        _10m"""
    filename_regex = s2_regex
    date_format = '%Y_%m_%d'
    rgb_bands = ['B04', 'B03', 'B02']
    s1_bands = ['VH', 'VV']
    s2_bands = [
        'B01',
        'B02',
        'B03',
        'B04',
        'B05',
        'B06',
        'B07',
        'B08',
        'B8A',
        'B09',
        'B11',
        'B12',
    ]
    all_bands: list[str] = s1_bands + s2_bands
    # RGBA colour for every crop type class id.
    cmap = {
        0: (0, 0, 0, 255),
        1: (255, 211, 0, 255),
        2: (255, 37, 37, 255),
        3: (0, 168, 226, 255),
        4: (255, 158, 9, 255),
        5: (37, 111, 0, 255),
        6: (255, 255, 0, 255),
        7: (222, 166, 9, 255),
        8: (111, 166, 0, 255),
        9: (0, 175, 73, 255),
    }

    def __init__(
        self,
        paths: str | Iterable[str] = 'data',
        crs: CRS | None = None,
        classes: list[int] = list(cmap.keys()),
        bands: list[str] = s2_bands,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
    ) -> None:
        """Initialize a new South Africa Crop Type dataset instance.

        Args:
            paths: paths directory where dataset can be found
            crs: coordinate reference system to be used
            classes: crop type classes to be included
            bands: the subset of bands to load
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version

        Raises:
            AssertionError: If *classes* contains an unknown class id or does not
                include the background class 0.
            DatasetNotFoundError: If dataset is not found (presumably raised by
                the parent class constructor).
        """
        assert (
            set(classes) <= self.cmap.keys()
        ), f'Only the following classes are valid: {list(self.cmap.keys())}.'
        assert 0 in classes, 'Classes must include the background class: 0'

        self.paths = paths
        # Copy so the instance never aliases the shared default list (or a
        # list the caller keeps mutating).
        self.classes = list(classes)
        self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
        self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)

        # Sentinel-1 files follow a different naming scheme than Sentinel-2.
        if set(bands).issubset(set(self.s1_bands)):
            self.filename_regex = self.s1_regex

        # Pass a copy: the default value is the class-level ``s2_bands`` list.
        super().__init__(
            paths=paths, crs=crs, bands=list(bands), transforms=transforms
        )

        # Map chosen classes to ordinal numbers, all others mapped to background class
        for ordinal, class_id in enumerate(self.classes):
            self.ordinal_map[class_id] = ordinal
            self.ordinal_cmap[ordinal] = torch.tensor(self.cmap[class_id])

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve imagery, labels and metadata for the given bounding box.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            a sample dict with keys ``'crs'``, ``'bbox'``, ``'image'`` and ``'mask'``

        Raises:
            IndexError: if query is not found in the index
        """
        assert isinstance(self.paths, str)

        # Get all files matching the given query
        hits = self.index.intersection(tuple(query), objects=True)
        filepaths = cast(list[str], [hit.object for hit in hits])

        if not filepaths:
            raise IndexError(
                f'query: {query} not found in index with bounds: {self.bounds}'
            )

        filename_regex = re.compile(self.filename_regex, re.VERBOSE)

        # First July date seen per field id, separately for s1 and s2 imagery.
        # Dict insertion order doubles as the ordered list of unique field ids.
        imagery_dates: dict[str, dict[str, str]] = {}
        for filepath in filepaths:
            match = filename_regex.match(os.path.basename(filepath))
            if match is None:
                continue

            field_id = match.group('field_id')
            date = match.group('date')
            band = match.group('band')
            band_type = 's1' if band in self.s1_bands else 's2'

            dates = imagery_dates.setdefault(field_id, {'s1': '', 's2': ''})
            if date.split('_')[1] == '07' and not dates[band_type]:
                dates[band_type] = date

        field_ids = list(imagery_dates)

        # Create Tensors for each band using stored dates
        data_list: list[Tensor] = []
        for band in self.bands:
            band_type = 's1' if band in self.s1_bands else 's2'
            band_filepaths = []
            for field_id in field_ids:
                date = imagery_dates[field_id][band_type]
                band_filepaths.append(
                    os.path.join(
                        self.paths,
                        'train',
                        'imagery',
                        band_type,
                        field_id,
                        date,
                        f'{field_id}_{date}_{band}_10m.tif',
                    )
                )
            data_list.append(self._merge_files(band_filepaths, query))
        image = torch.cat(data_list)

        # Add labels for each field
        mask_filepaths = [
            os.path.join(self.paths, 'train', 'labels', f'{field_id}.tif')
            for field_id in field_ids
        ]
        mask = self._merge_files(mask_filepaths, query)

        sample = {
            'crs': self.crs,
            'bbox': query,
            'image': image.float(),
            'mask': mask.long(),
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        Raises:
            RGBBandsMissingError: If *bands* does not include all RGB bands.
        """
        rgb_indices = []
        for band in self.rgb_bands:
            if band not in self.bands:
                raise RGBBandsMissingError()
            rgb_indices.append(self.bands.index(band))

        image = sample['image'][rgb_indices].permute(1, 2, 0)
        # Min-max normalise for display; a constant image would otherwise
        # divide by zero and render as NaN.
        low = image.min()
        span = image.max() - low
        if span > 0:
            image = (image - low) / span
        else:
            image = torch.zeros_like(image)

        mask = sample['mask'].squeeze()
        ncols = 2

        showing_prediction = 'prediction' in sample
        if showing_prediction:
            pred = sample['prediction'].squeeze()
            ncols += 1

        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4))
        axs[0].imshow(image)
        axs[0].axis('off')
        axs[1].imshow(self.ordinal_cmap[mask], interpolation='none')
        axs[1].axis('off')
        if show_titles:
            axs[0].set_title('Image')
            axs[1].set_title('Mask')

        if showing_prediction:
            axs[2].imshow(pred)
            axs[2].axis('off')
            if show_titles:
                axs[2].set_title('Prediction')

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig

© Copyright 2021, Microsoft Corporation. Revision 94bd5c76.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources