
Source code for torchgeo.datasets.nwpu

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""NWPU VHR-10 dataset."""

import os
from typing import Any, Callable, Dict, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .utils import check_integrity, download_and_extract_archive, download_url


class VHR10(VisionDataset):
    """NWPU VHR-10 dataset.

    Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10)
    remote sensing image dataset.

    Consists of 800 VHR optical remote sensing images: 715 color images acquired
    from Google Earth with spatial resolutions ranging from 0.5 to 2 m, and 85
    pansharpened color infrared (CIR) images acquired from the Vaihingen dataset
    with a spatial resolution of 0.08 m.

    The dataset is divided into two sets:

    * Positive image set (650 images): each image contains at least one target
    * Negative image set (150 images): images do not contain any targets

    The positive image set consists of objects from ten classes:

    1. Airplanes (757)
    2. Ships (302)
    3. Storage tanks (655)
    4. Baseball diamonds (390)
    5. Tennis courts (524)
    6. Basketball courts (159)
    7. Ground track fields (163)
    8. Harbors (224)
    9. Bridges (124)
    10. Vehicles (477)

    Includes object detection bounding boxes from the original paper and instance
    segmentation masks from follow-up publications. If you use this dataset in your
    research, please cite the following papers:

    * https://doi.org/10.1016/j.isprsjprs.2014.10.002
    * https://doi.org/10.1109/IGARSS.2019.8898573
    * https://doi.org/10.3390/rs12060989

    .. note::

       This dataset requires the following additional libraries to be installed:

       * `pycocotools <https://pypi.org/project/pycocotools/>`_ to load the
         ``annotations.json`` file for the "positive" image set
       * `rarfile <https://pypi.org/project/rarfile/>`_ to extract the dataset,
         which is stored in a RAR file
    """

    image_meta = {
        "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE",
        "filename": "NWPU VHR-10 dataset.rar",
        "md5": "d30a7ff99d92123ebb0b3a14d9102081",
    }
    target_meta = {
        "url": (
            "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/"
            "master/NWPU%20VHR-10_dataset_coco/annotations.json"
        ),
        "filename": "annotations.json",
        "md5": "7c76ec50c17a61bb0514050d20f22c08",
    }

    def __init__(
        self,
        root: str = "data",
        split: str = "positive",
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new VHR-10 dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "positive" or "negative"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` argument is invalid
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        assert split in ["positive", "negative"]

        self.root = root
        self.split = split
        self.transforms = transforms
        self.checksum = checksum

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                + "You can use download=True to download it"
            )

        if split == "positive":
            # Must be installed to parse annotations file
            try:
                from pycocotools.coco import COCO  # noqa: F401
            except ImportError:
                raise ImportError(
                    "pycocotools is not installed and is required to use this dataset"
                )

            self.coco = COCO(
                os.path.join(
                    self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
                )
            )
    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        id_ = index % len(self) + 1
        sample = {"image": self._load_image(id_), "label": self._load_target(id_)}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        if self.split == "positive":
            return 650
        else:
            return 150
    def _load_image(self, id_: int) -> Tensor:
        """Load a single image.

        Args:
            id_: unique ID of the image

        Returns:
            the image
        """
        filename = os.path.join(
            self.root,
            "NWPU VHR-10 dataset",
            self.split + " image set",
            f"{id_:03d}.jpg",
        )
        with Image.open(filename) as img:
            array = np.array(img)
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            # Convert from HxWxC to CxHxW
            tensor = tensor.permute((2, 0, 1))
            return tensor

    def _load_target(self, id_: int) -> Dict[str, Any]:
        """Load the annotations for a single image.

        Args:
            id_: unique ID of the image

        Returns:
            the annotations
        """
        # Images in the "negative" image set have no annotations
        annot = []
        if self.split == "positive":
            annot = self.coco.loadAnns(self.coco.getAnnIds(id_))

        target = dict(image_id=id_, annotations=annot)

        return target

    def _check_integrity(self) -> bool:
        """Check integrity of dataset.

        Returns:
            True if dataset files are found and/or MD5s match, else False
        """
        image: bool = check_integrity(
            os.path.join(self.root, self.image_meta["filename"]),
            self.image_meta["md5"] if self.checksum else None,
        )

        # Annotations only needed for "positive" image set
        target = True
        if self.split == "positive":
            target = check_integrity(
                os.path.join(
                    self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
                ),
                self.target_meta["md5"] if self.checksum else None,
            )

        return image and target

    def _download(self) -> None:
        """Download the dataset and extract it."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        # Download images
        download_and_extract_archive(
            self.image_meta["url"],
            self.root,
            filename=self.image_meta["filename"],
            md5=self.image_meta["md5"] if self.checksum else None,
        )

        # Annotations only needed for "positive" image set
        if self.split == "positive":
            # Download annotations
            download_url(
                self.target_meta["url"],
                os.path.join(self.root, "NWPU VHR-10 dataset"),
                self.target_meta["filename"],
                self.target_meta["md5"] if self.checksum else None,
            )
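
A minimal usage sketch, assuming ``VHR10`` is exported from ``torchgeo.datasets`` and
that the optional ``pycocotools`` and ``rarfile`` dependencies are installed::

    from torchgeo.datasets import VHR10  # assumed public export of this module

    # Download and verify the "positive" split into ./data; this fetches the
    # RAR archive of images plus the COCO-style annotations.json file.
    dataset = VHR10(root="data", split="positive", download=True, checksum=True)

    print(len(dataset))      # 650 images in the positive split
    sample = dataset[0]      # dict with "image" and "label" keys
    image = sample["image"]  # uint8 Tensor in CxHxW order
    label = sample["label"]  # dict with "image_id" and "annotations"
    print(image.shape, label["image_id"], len(label["annotations"]))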
