Shortcuts
Open in Colab Open on Planetary Computer

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Custom Raster Datasets

In this tutorial, we demonstrate how to create a custom RasterDataset for our own data. We will use the xView3 tiny dataset as an example.

Setup

[ ]:
%pip install torchgeo

Imports

[6]:
from pathlib import Path
from typing import Callable, Dict, Optional

import matplotlib.pyplot as plt
import torch
from rasterio.crs import CRS
from torch import Tensor
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.samplers import RandomGeoSampler

Custom RasterDataset

Unzipping the sample xView3 data from the tests folder

[7]:
from torchgeo.datasets.utils import extract_archive

data_root = Path('../../tests/data/xview3/')
extract_archive(str(data_root / 'sample_data.tar.gz'))

Now we have the xView3 tiny dataset downloaded and unzipped in our local directory. Note that all the test GeoTIFFs are comprised entirely of zeros. Any plotted image will appear to be entirely uniform.

xview3
├── 05bc615a9b0aaaaaa
│   ├── bathymetry.tif
│   ├── owiMask.tif
│   ├── owiWindDirection.tif
│   ├── owiWindQuality.tif
│   ├── owiWindSpeed.tif
│   ├── VH_dB.tif
│   └── VV_dB.tif

We would like to create a custom Dataset class based off of RasterDataset for this xView3 data. This will let us use torchgeo features such as: random sampling, merging other layers, fusing multiple datasets with UnionDataset and IntersectionDataset, and more. To do this, we can simply subclass RasterDataset and define a filename_glob property to select which files in a root directory will be included in the dataset. For example:

[8]:
class XView3Polarizations(RasterDataset):
    '''
    Load xView3 polarization data that ends in *_dB.tif
    '''

    filename_glob = "*_dB.tif"
[9]:
ds = XView3Polarizations(data_root)
sampler = RandomGeoSampler(ds, size=1024, length=5)
dl = DataLoader(ds, sampler=sampler, collate_fn=stack_samples)

for sample in dl:
    image = sample['image']
    print(image.shape)
    image = torch.squeeze(image)
    plt.imshow(image, cmap='bone', vmin=-35, vmax=-5)
torch.Size([1, 1, 102, 102])
torch.Size([1, 1, 102, 102])
torch.Size([1, 1, 102, 102])
torch.Size([1, 1, 102, 102])
torch.Size([1, 1, 102, 102])
../_images/tutorials_custom_raster_dataset_11_1.png
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources