# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""So2Sat dataset."""

import os
from typing import Any, Callable, Dict, Optional, cast

import numpy as np
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from .geo import VisionDataset
from .utils import check_integrity

# Advertise DataLoader under its public import path. Tools that resolve a class
# through ``__module__`` (documentation cross-references, pickling by reference)
# then find it; an empty string names no importable module and breaks them.
# Note this patches torch's class object, so it is visible to every importer.
DataLoader.__module__ = "torch.utils.data"

class So2Sat(VisionDataset):
    """So2Sat dataset.

    The `So2Sat <https://doi.org/10.1109/MGRS.2020.2964708>`_ dataset consists of
    corresponding synthetic aperture radar and multispectral optical image data
    acquired by the Sentinel-1 and Sentinel-2 remote sensing satellites, and a
    corresponding local climate zones (LCZ) label. The training data covers 42
    cities across different continents and cultural regions of the world; 10
    further cities are held out and split between validation and testing, giving
    fully independent, non-overlapping training, validation, and test sets.

    This implementation focuses on the *2nd* version of the dataset as described in
    the authors' GitHub repository (https://github.com/zhu-xlab/So2Sat-LCZ42) and
    hosted at https://mediatum.ub.tum.de/1483140. This version is identical to the
    first version of the dataset but includes the test data. The splits are defined
    as follows:

    * Training: 42 cities around the world
    * Validation: western half of 10 other cities covering 10 cultural zones
    * Testing: eastern half of the 10 other cities

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1109/MGRS.2020.2964708

    .. note::

       This dataset can be automatically downloaded using the following bash script:

       .. code-block:: bash

          for split in training validation testing
          do
              wget ftp://m1483140:m1483140@dataserv.ub.tum.de/$split.h5
          done

       or manually downloaded from https://dataserv.ub.tum.de/index.php/s/m1483140
       This download will likely take several hours.
    """

    # HDF5 file name for each split, relative to ``root``.
    filenames = {
        "train": "training.h5",
        "validation": "validation.h5",
        "test": "testing.h5",
    }
    # Expected MD5 digest of each split's file (only checked when checksum=True).
    md5s = {
        "train": "702bc6a9368ebff4542d791e53469244",
        "validation": "71cfa6795de3e22207229d06d6f8775d",
        "test": "e81426102b488623a723beab52b31a8a",
    }

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new So2Sat dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "validation", or "test"
            transforms: a function/transform that takes an input sample dict and
                returns a transformed version
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` argument is invalid
            ImportError: if h5py is not installed
            RuntimeError: if data is not found in ``root``, or checksums don't match
        """
        try:
            import h5py
        except ImportError as err:
            raise ImportError(
                "h5py is not installed and is required to use this dataset"
            ) from err

        assert split in ["train", "validation", "test"], f"invalid split: {split!r}"

        self.root = root
        self.split = split
        self.transforms = transforms
        self.checksum = checksum

        # Verifies the files of *all* splits, not only the requested one.
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted.")

        self.fn = os.path.join(self.root, self.filenames[split])
        # Only the length is cached; samples are read lazily in __getitem__.
        with h5py.File(self.fn, "r") as f:
            self.size = int(f["label"].shape[0])

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index; ``image`` stacks the Sentinel-1 channels
            followed by the Sentinel-2 channels in CxHxW order
        """
        import h5py

        # The file is reopened per sample rather than kept open on the instance.
        with h5py.File(self.fn, "r") as f:
            s1 = f["sen1"][index].astype(np.float64)  # convert from <f8 to float64
            s2 = f["sen2"][index].astype(np.float64)  # convert from <f8 to float64
            # The stored label is one-hot encoded; argmax recovers the class index.
            # NOTE(review): it is returned as a Python int, not a Tensor, despite
            # the ``Dict[str, Tensor]`` return annotation.
            label = int(f["label"][index].argmax())

        # Move the channel axis from last to first (HxWxC -> CxHxW).
        s1 = np.moveaxis(s1, 2, 0)
        s2 = np.moveaxis(s2, 2, 0)

        s1_tensor = torch.from_numpy(s1)  # type: ignore[attr-defined]
        s2_tensor = torch.from_numpy(s2)  # type: ignore[attr-defined]

        sample = {
            "image": torch.cat([s1_tensor, s2_tensor]),  # type: ignore[attr-defined]
            "label": label,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return self.size

    def _check_integrity(self) -> bool:
        """Check integrity of dataset.

        The file of every split is checked, regardless of ``self.split``, so all
        three files must be present under ``root``.

        Returns:
            True if dataset files are found and/or MD5s match, else False
        """
        for split_name, filename in self.filenames.items():
            filepath = os.path.join(self.root, filename)
            # Passing None as the digest skips the (slow) MD5 computation.
            md5 = self.md5s[split_name] if self.checksum else None
            if not check_integrity(filepath, md5):
                return False
        return True


class So2SatDataModule(pl.LightningDataModule):
    """LightningDataModule implementation for the So2Sat dataset.

    Uses the train/val/test splits from the dataset.
    """

    # Per-channel statistics for the 18 stacked channels, in the order produced by
    # So2Sat.__getitem__ (Sentinel-1 channels first, then Sentinel-2). They are
    # currently unused because normalization is disabled in ``preprocess``.
    band_means = torch.tensor(  # type: ignore[attr-defined]
        [
            -3.591224256609313e-05,
            -7.658561276843396e-06,
            5.9373857475971184e-05,
            2.5166231537121083e-05,
            0.04420110659759328,
            0.25761027084996196,
            0.0007556743372573258,
            0.0013503466830024448,
            0.12375696117681859,
            0.1092774636368323,
            0.1010855203267882,
            0.1142398616114001,
            0.1592656692023089,
            0.18147236008771792,
            0.1745740312291377,
            0.19501607349635292,
            0.15428468872076637,
            0.10905050699570007,
        ]
    ).reshape(18, 1, 1)

    band_stds = torch.tensor(  # type: ignore[attr-defined]
        [
            0.17555201137417686,
            0.17556463274968204,
            0.45998793417834255,
            0.455988755730148,
            2.8559909213125763,
            8.324800606439833,
            2.4498757382563103,
            1.4647352984509094,
            0.03958795985905458,
            0.047778262752410296,
            0.06636616706371974,
            0.06358874912497474,
            0.07744387147984592,
            0.09101635085921553,
            0.09218466562387101,
            0.10164581233948201,
            0.09991773043519253,
            0.08780632509122865,
        ]
    ).reshape(18, 1, 1)

    # Reorders the channels to put the Sentinel-2 RGB bands first, followed by the
    # remaining Sentinel-2 bands. The Sentinel-1 channels (indices 0-7) are
    # intentionally excluded, so the reindexed image has 10 channels.
    reindex_to_rgb_first = [10, 9, 8, 11, 12, 13, 14, 15, 16, 17]

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 64,
        num_workers: int = 0,
        bands: str = "rgb",
        unsupervised_mode: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize a LightningDataModule for So2Sat based DataLoaders.

        Args:
            root_dir: The ``root`` argument to pass to the So2Sat Dataset classes
            batch_size: The batch size to use in all created DataLoaders
            num_workers: The number of workers to use in all created DataLoaders
            bands: Either "rgb" (first 3 channels of the reindexed image) or "s2"
                (all channels in ``reindex_to_rgb_first``); any value other than
                "rgb" currently behaves like "s2"
            unsupervised_mode: Makes the train dataloader return imagery from the
                train, val, and test sets
            **kwargs: Extra keyword arguments; currently ignored
        """
        super().__init__()  # type: ignore[no-untyped-call]
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.bands = bands
        self.unsupervised_mode = unsupervised_mode

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform a single sample from the Dataset.

        Args:
            sample: dictionary containing image

        Returns:
            preprocessed sample (the input dict, modified in place)
        """
        # NOTE: normalization is intentionally disabled. To enable it, apply
        # ``(image - self.band_means) / self.band_stds`` here, *before* reindexing,
        # since the statistics are in the original 18-channel order.
        image = sample["image"].float()
        image = image[self.reindex_to_rgb_first, :, :]
        if self.bands == "rgb":
            image = image[:3, :, :]
        sample["image"] = image
        return sample

    def prepare_data(self) -> None:
        """Check that the dataset has been downloaded.

        Constructing the train split runs So2Sat's integrity check, which raises
        RuntimeError if any split's file is missing. Nothing is downloaded here.

        This method is only called once per run.
        """
        So2Sat(self.root_dir, checksum=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up (currently ignored; all splits are built)
        """
        train_transforms = Compose([self.preprocess])
        val_test_transforms = self.preprocess

        if not self.unsupervised_mode:
            self.train_dataset = So2Sat(
                self.root_dir, split="train", transforms=train_transforms
            )
            self.val_dataset = So2Sat(
                self.root_dir, split="validation", transforms=val_test_transforms
            )
            self.test_dataset = So2Sat(
                self.root_dir, split="test", transforms=val_test_transforms
            )
        else:
            temp_train = So2Sat(
                self.root_dir, split="train", transforms=train_transforms
            )
            self.val_dataset = So2Sat(
                self.root_dir, split="validation", transforms=train_transforms
            )
            self.test_dataset = So2Sat(
                self.root_dir, split="test", transforms=train_transforms
            )
            # NOTE(review): ``+`` presumably yields a torch ConcatDataset, so the
            # ``cast`` only satisfies the type checker.
            self.train_dataset = cast(
                So2Sat, temp_train + self.val_dataset + self.test_dataset
            )

    def _dataloader(self, dataset: Any, shuffle: bool) -> DataLoader[Any]:
        """Wrap ``dataset`` in a DataLoader using this module's batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return self._dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return self._dataloader(self.test_dataset, shuffle=False)

