
torchgeo.trainers

TorchGeo trainers.

class torchgeo.trainers.BYOLTask(**kwargs)

Bases: LightningModule

Class for pre-training any PyTorch model using BYOL.

Supports any available timm model as an architecture choice. To see a list of the available models with pretrained weights, you can do:

```python
import timm
print(timm.list_models(pretrained=True))
```

config_task()

Configures the task based on kwargs parameters passed to the constructor.

__init__(**kwargs)

Initialize a LightningModule for pre-training a model with BYOL.

Keyword Arguments:
  • in_channels – Number of input channels to the model

  • backbone – Name of the timm model to use

  • weights – Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • learning_rate – Learning rate for the optimizer

  • learning_rate_schedule_patience – Patience for the learning rate scheduler

Raises:

ValueError – if the keyword arguments are invalid

Changed in version 0.4: The backbone_name parameter was renamed to backbone, and backbone support changed from torchvision.models to timm.
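The keyword renames listed in the “Changed in version” notes on this page are mechanical, so an older config can be upgraded before it is passed to a task constructor (for example `BYOLTask(**migrate_kwargs(old_config))`). A minimal sketch; `migrate_kwargs` and `RENAMED_KWARGS` are hypothetical helpers, not part of torchgeo. The 0.3 `ignore_zeros` change is left out because it is not a plain rename (a boolean became an integer index).

```python
from typing import Any, Dict

# Hypothetical table of pre-0.4 keyword names and their 0.4 replacements,
# copied from the "Changed in version 0.4" notes in this reference.
RENAMED_KWARGS: Dict[str, str] = {
    "backbone_name": "backbone",       # BYOLTask
    "classification_model": "model",   # ClassificationTask, MultiLabelClassificationTask
    "detection_model": "model",        # ObjectDetectionTask
    "segmentation_model": "model",     # SemanticSegmentationTask
    "encoder_name": "backbone",        # SemanticSegmentationTask
    "encoder_weights": "weights",      # SemanticSegmentationTask
}


def migrate_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of kwargs with legacy keyword names replaced by 0.4 names.

    Raises ValueError if a legacy name and its replacement are both present.
    """
    migrated = dict(kwargs)  # never mutate the caller's config
    for old, new in RENAMED_KWARGS.items():
        if old in migrated:
            if new in migrated:
                raise ValueError(f"got both {old!r} and its replacement {new!r}")
            migrated[new] = migrated.pop(old)
    return migrated
```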

forward(*args, **kwargs)

Forward pass of the model.

Parameters:

x – tensor of data to run through the model

Returns:

output from the model

Return type:

Any

configure_optimizers()

Initialize the optimizer and learning rate scheduler.

Returns:

an “lr dict”, as described in the PyTorch Lightning documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type:

Dict[str, Any]
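In Lightning's format, an lr dict has the shape `{"optimizer": ..., "lr_scheduler": {"scheduler": ..., "monitor": ...}}`. Assuming the scheduler is a ReduceLROnPlateau-style scheduler watching the validation loss (this reference does not say which scheduler is used), learning_rate_schedule_patience is the number of non-improving epochs tolerated before the rate is cut. A simplified plain-Python sketch of that rule; the real PyTorch scheduler also has threshold, cooldown, and min_lr settings that are omitted here, and factor=0.1 mirrors PyTorch's default.

```python
from typing import List


def plateau_schedule(
    val_losses: List[float], lr: float, patience: int, factor: float = 0.1
) -> List[float]:
    """Return the learning rate in effect after each epoch.

    Simplified ReduceLROnPlateau (mode="min"): once the monitored loss has
    failed to improve on its best value for more than `patience` consecutive
    epochs, the rate is multiplied by `factor` and the counter resets.
    """
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs
```

With patience=2, the rate drops only on the third consecutive epoch without a new best loss.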

training_step(*args, **kwargs)

Compute and return the training loss.

Parameters:

batch – the output of your DataLoader

Returns:

training loss

Return type:

Tensor
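BYOL's training loss compares the online network's prediction for one augmented view with the target network's projection of the other view, and the target network's weights are an exponential moving average (EMA) of the online weights. A plain-Python sketch of those two pieces, following the BYOL paper rather than torchgeo's tensor code, which may differ in details; the tau value is illustrative only.

```python
import math
from typing import List, Sequence


def normalized_mse(p: Sequence[float], z: Sequence[float]) -> float:
    """BYOL regression loss between an online prediction p and a target projection z.

    Both vectors are L2-normalized first, so the result equals 2 - 2 * cos(p, z).
    """
    norm_p = math.sqrt(sum(x * x for x in p))
    norm_z = math.sqrt(sum(x * x for x in z))
    return sum((a / norm_p - b / norm_z) ** 2 for a, b in zip(p, z))


def byol_loss(
    p1: Sequence[float], z2: Sequence[float], p2: Sequence[float], z1: Sequence[float]
) -> float:
    """Symmetrized loss: each view's prediction regresses the other view's target."""
    return normalized_mse(p1, z2) + normalized_mse(p2, z1)


def ema_update(target: Sequence[float], online: Sequence[float], tau: float = 0.99) -> List[float]:
    """Move target weights a small step toward the online weights (illustrative tau)."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]
```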

validation_step(*args, **kwargs)

Compute validation loss.

Parameters:

batch – the output of your DataLoader

test_step(*args, **kwargs)

No-op, does nothing.

predict_step(*args, **kwargs)

Compute and return the output embeddings of the image backbone.

Parameters:

batch – the output of your DataLoader

Returns:

image embeddings

Return type:

Tensor

class torchgeo.trainers.ClassificationTask(**kwargs)

Bases: LightningModule

LightningModule for image classification.

Supports any available timm model as an architecture choice. To see a list of available models, you can do:

```python
import timm
print(timm.list_models())
```

config_model()

Configures the model based on kwargs parameters passed to the constructor.

config_task()

Configures the task based on kwargs parameters passed to the constructor.

__init__(**kwargs)

Initialize the LightningModule with a model and loss function.

Keyword Arguments:
  • model – Name of the classification model to use

  • loss – Name of the loss function, accepts ‘ce’, ‘jaccard’, or ‘focal’

  • weights – Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • num_classes – Number of prediction classes

  • in_channels – Number of input channels to the model

  • learning_rate – Learning rate for the optimizer

  • learning_rate_schedule_patience – Patience for the learning rate scheduler

Changed in version 0.4: The classification_model parameter was renamed to model.
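The ‘ce’ and ‘focal’ loss options differ in how much weight they give to examples the model already gets right. A per-sample plain-Python sketch of both, using the standard focal-loss definition with an assumed gamma of 2; torchgeo's own loss objects work on batched tensors and their settings are not given in this reference. The ‘jaccard’ option is not shown.

```python
import math
from typing import Sequence


def cross_entropy(probs: Sequence[float], target: int) -> float:
    """Cross-entropy for one sample, given its predicted class probabilities."""
    return -math.log(probs[target])


def focal_loss(probs: Sequence[float], target: int, gamma: float = 2.0) -> float:
    """Focal loss for one sample: -(1 - p_t) ** gamma * log(p_t).

    With gamma = 0 this is exactly cross-entropy; larger gamma shrinks the
    loss of confidently correct (easy) examples so hard ones dominate.
    """
    p_t = probs[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)
```

For an easy example with p_t = 0.9 and gamma = 2, the focal loss is (1 - 0.9)² = 0.01 times the cross-entropy.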

forward(*args, **kwargs)

Forward pass of the model.

Parameters:

x – input image

Returns:

prediction

Return type:

Any

training_step(*args, **kwargs)

Compute and return the training loss.

Parameters:

batch – the output of your DataLoader

Returns:

training loss

Return type:

Tensor

training_epoch_end(outputs)

Logs epoch-level training metrics.

Parameters:

outputs (Any) – list of items returned by training_step

validation_step(*args, **kwargs)

Compute validation loss and log example predictions.

Parameters:
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

validation_epoch_end(outputs)

Logs epoch-level validation metrics.

Parameters:

outputs (Any) – list of items returned by validation_step

test_step(*args, **kwargs)

Compute test loss.

Parameters:

batch – the output of your DataLoader

test_epoch_end(outputs)

Logs epoch-level test metrics.

Parameters:

outputs (Any) – list of items returned by test_step

predict_step(*args, **kwargs)

Compute and return the predictions.

Parameters:

batch – the output of your DataLoader

Returns:

predicted softmax probabilities

Return type:

Tensor
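Because predict_step returns probabilities rather than class indices, a hard label is usually taken with an argmax. A plain-Python sketch of what a softmax over one sample's logits and the argmax compute; in practice this runs on tensors (for example with `torch.softmax` and `argmax`).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw class scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_class(probs: Sequence[float]) -> int:
    """Index of the most probable class."""
    return max(range(len(probs)), key=probs.__getitem__)
```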

configure_optimizers()

Initialize the optimizer and learning rate scheduler.

Returns:

an “lr dict”, as described in the PyTorch Lightning documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type:

Dict[str, Any]

class torchgeo.trainers.MultiLabelClassificationTask(**kwargs)

Bases: ClassificationTask

LightningModule for multi-label image classification.

config_task()

Configures the task based on kwargs parameters passed to the constructor.

__init__(**kwargs)

Initialize the LightningModule with a model and loss function.

Keyword Arguments:
  • model – Name of the classification model to use

  • loss – Name of the loss function, currently only supports ‘bce’

  • weights – Either ‘random’ or ‘imagenet’

  • num_classes – Number of prediction classes

  • in_channels – Number of input channels to the model

  • learning_rate – Learning rate for the optimizer

  • learning_rate_schedule_patience – Patience for the learning rate scheduler

Changed in version 0.4: The classification_model parameter was renamed to model.
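With ‘bce’, each label is an independent yes/no decision, so one image can have any number of positive labels. A plain-Python sketch of binary cross-entropy computed directly from raw logits, using a numerically stable form that is mathematically equivalent to what PyTorch's `BCEWithLogitsLoss` computes; whether the trainer uses that exact class is an assumption.

```python
import math
from typing import Sequence


def bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy over independent labels, computed from logits.

    Uses max(x, 0) - x * y + log(1 + exp(-|x|)), which equals
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))] but never overflows.
    """
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)
```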

training_step(*args, **kwargs)

Compute and return the training loss.

Parameters:

batch – the output of your DataLoader

Returns:

training loss

Return type:

Tensor

validation_step(*args, **kwargs)

Compute validation loss and log example predictions.

Parameters:
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

test_step(*args, **kwargs)

Compute test loss.

Parameters:

batch – the output of your DataLoader

predict_step(*args, **kwargs)

Compute and return the predictions.

Parameters:

batch – the output of your DataLoader

Returns:

predicted sigmoid probabilities

Return type:

Tensor
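Unlike the softmax output of ClassificationTask, these sigmoid probabilities do not sum to 1 across labels, so labels are usually chosen by thresholding each one separately. A plain-Python sketch; the 0.5 threshold is an assumption and is often tuned per label in practice.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def predicted_labels(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Indices of every label whose probability reaches the threshold (possibly none)."""
    return [i for i, x in enumerate(logits) if sigmoid(x) >= threshold]
```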

class torchgeo.trainers.ObjectDetectionTask(**kwargs)

Bases: LightningModule

LightningModule for object detection of images.

Currently supports Faster R-CNN, FCOS, and RetinaNet models from torchvision with one of the following backbone arguments:

['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2',
'wide_resnet101_2']

New in version 0.4.

config_task()

Configures the task based on kwargs parameters passed to the constructor.

__init__(**kwargs)

Initialize the LightningModule with a model and loss function.

Keyword Arguments:
  • model – Name of the detection model type to use

  • backbone – Name of the model backbone to use

  • in_channels – Number of channels in the input image

  • num_classes – Number of semantic classes to predict

  • learning_rate – Learning rate for the optimizer

  • learning_rate_schedule_patience – Patience for the learning rate scheduler

Raises:

ValueError – if the keyword arguments are invalid

Changed in version 0.4: The detection_model parameter was renamed to model.
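Since an invalid backbone only surfaces as a ValueError once the task is built, a config can be checked up front against the backbone list given for this class. A sketch with a hypothetical helper, `check_detection_kwargs`; the set is copied from this page, and the names accepted by `model` are not listed here, so they are not checked.

```python
from typing import Dict

# Backbones listed in this reference for ObjectDetectionTask.
SUPPORTED_BACKBONES = frozenset({
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d",
    "wide_resnet50_2", "wide_resnet101_2",
})


def check_detection_kwargs(kwargs: Dict[str, object]) -> None:
    """Hypothetical pre-flight check: raise ValueError for an unsupported backbone."""
    backbone = kwargs.get("backbone")
    if backbone not in SUPPORTED_BACKBONES:
        raise ValueError(
            f"backbone {backbone!r} is not supported; "
            f"choose one of {sorted(SUPPORTED_BACKBONES)}"
        )
```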

forward(*args, **kwargs)

Forward pass of the model.

Parameters:

x – tensor of data to run through the model

Returns:

output from the model

Return type:

Any

training_step(*args, **kwargs)

Compute and return the training loss.

Parameters:

batch – the output of your DataLoader

Returns:

training loss

Return type:

Tensor

validation_step(*args, **kwargs)

Compute validation loss and log example predictions.

Parameters:
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

validation_epoch_end(outputs)

Logs epoch-level validation metrics.

Parameters:

outputs (Any) – list of items returned by validation_step

test_step(*args, **kwargs)

Compute test mAP (mean average precision).

Parameters:

batch – the output of your DataLoader
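mAP counts a predicted box as a true positive only when its intersection over union (IoU) with a ground-truth box of the same class exceeds a threshold; average precision is then computed per class and averaged. A plain-Python sketch of the IoU step for boxes in (xmin, ymin, xmax, ymax) order, the format torchvision detection models use; which IoU thresholds the test metric applies is not stated in this reference.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)  # 0 if the boxes do not overlap
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```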

test_epoch_end(outputs)

Logs epoch-level test metrics.

Parameters:

outputs (Any) – list of items returned by test_step

predict_step(*args, **kwargs)

Compute and return the predictions.

Parameters:

batch – the output of your DataLoader

Returns:

list of predicted boxes, labels and scores

Return type:

List[Dict[str, Tensor]]
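Each dict in the returned list describes one image, following torchvision's detection output with "boxes", "labels", and "scores" entries. A common post-processing step is to drop low-confidence detections. A plain-Python sketch in which lists stand in for tensors and the 0.5 cutoff is an assumption; with real tensors the same filter is a boolean mask such as `pred["scores"] >= min_score`.

```python
from typing import Dict, List


def filter_by_score(
    predictions: List[Dict[str, list]], min_score: float = 0.5
) -> List[Dict[str, list]]:
    """Keep only detections whose score reaches min_score, image by image."""
    filtered = []
    for pred in predictions:
        keep = [i for i, s in enumerate(pred["scores"]) if s >= min_score]
        # Index boxes, labels and scores with the same positions so they stay aligned.
        filtered.append(
            {key: [pred[key][i] for i in keep] for key in ("boxes", "labels", "scores")}
        )
    return filtered
```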

configure_optimizers()

Initialize the optimizer and learning rate scheduler.

Returns:

an “lr dict”, as described in the PyTorch Lightning documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type:

Dict[str, Any]

class torchgeo.trainers.RegressionTask(**kwargs)

Bases: LightningModule

LightningModule for training models on regression datasets.

Supports any available timm model as an architecture choice. To see a list of available models, you can do:

```python
import timm
print(timm.list_models())
```

config_task()

Configures the task based on kwargs parameters.

__init__(**kwargs)

Initialize a new LightningModule for training simple regression models.

Keyword Arguments:
  • model – Name of the timm model to use

  • weights – Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • num_outputs – Number of prediction outputs

  • in_channels – Number of input channels to the model

  • learning_rate – Learning rate for the optimizer

  • learning_rate_schedule_patience – Patience for the learning rate scheduler

Changed in version 0.4: Regression model support changed from torchvision.models to timm.

forward(*args, **kwargs)

Forward pass of the model.

Parameters:

x – tensor of data to run through the model

Returns:

output from the model

Return type:

Any

training_step(*args, **kwargs)

Compute and return the training loss.

Parameters:

batch – the output of your DataLoader

Returns:

training loss

Return type:

Tensor
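For regression, the loss and metrics compare continuous predictions with continuous targets. A plain-Python sketch of mean squared error (a common training loss for this kind of task) and the RMSE and MAE metrics often logged alongside it; which loss and metrics RegressionTask actually uses is not stated in this reference, so treat that choice as an assumption.

```python
import math
from typing import Dict, Sequence


def regression_metrics(preds: Sequence[float], targets: Sequence[float]) -> Dict[str, float]:
    """Mean squared error, root mean squared error, and mean absolute error."""
    n = len(preds)
    errors = [p - t for p, t in zip(preds, targets)]
    mse = sum(e * e for e in errors) / n
    return {
        "mse": mse,
        "rmse": math.sqrt(mse),  # same units as the targets
        "mae": sum(abs(e) for e in errors) / n,
    }
```

RMSE penalizes a few large errors more heavily than MAE does, which is why the two are often reported together.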

training_epoch_end(outputs)

Logs epoch-level training metrics.

Parameters:

outputs (Any) – list of items returned by training_step

validation_step(*args, **kwargs)

Compute validation loss and log example predictions.

Parameters:
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

validation_epoch_end(outputs)

Logs epoch-level validation metrics.

Parameters:

outputs (Any) – list of items returned by validation_step

test_step(*args, **kwargs)

Compute test loss.

Parameters:

batch – the output of your DataLoader

test_epoch_end(outputs)

Logs epoch-level test metrics.

Parameters:

outputs (Any) – list of items returned by test_step

predict_step(*args, **kwargs)

Compute and return the predictions.

Parameters:

batch – the output of your DataLoader

Returns:

predicted values

Return type:

Tensor

configure_optimizers()

Initialize the optimizer and learning rate scheduler.

Returns:

an “lr dict”, as described in the PyTorch Lightning documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type:

Dict[str, Any]

class torchgeo.trainers.SemanticSegmentationTask(**kwargs)

Bases: LightningModule

LightningModule for semantic segmentation of images.

Supports Segmentation Models PyTorch architectures in combination with any of the timm backbones that Segmentation Models PyTorch supports.

config_task()

Configures the task based on kwargs parameters passed to the constructor.

__init__(**kwargs)

Initialize the LightningModule with a model and loss function.

Keyword Arguments:
  • model – Name of the segmentation model type to use

  • backbone – Name of the timm backbone to use

  • weights – None for random weights, or “imagenet” to use ImageNet-pretrained weights in the backbone

  • in_channels – Number of channels in the input image

  • num_classes – Number of semantic classes to predict

  • loss – Name of the loss function, currently supports ‘ce’, ‘jaccard’, or ‘focal’

  • ignore_index – Optional integer class index to ignore in the loss and metrics

  • learning_rate – Learning rate for the optimizer

  • learning_rate_schedule_patience – Patience for the learning rate scheduler

Raises:

ValueError – if the keyword arguments are invalid

Changed in version 0.3: The ignore_zeros parameter was renamed to ignore_index.

Changed in version 0.4: The segmentation_model parameter was renamed to model, encoder_name renamed to backbone, and encoder_weights to weights.
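ignore_index excludes pixels carrying a given label (for example an unlabeled or nodata class) from both the loss and the metrics. A plain-Python sketch of the effect on pixel accuracy over flattened label maps; torchgeo computes its metrics on tensors, and this is not its code.

```python
from typing import Optional, Sequence


def pixel_accuracy(
    preds: Sequence[int], targets: Sequence[int], ignore_index: Optional[int] = None
) -> float:
    """Fraction of correctly classified pixels, skipping pixels labeled ignore_index."""
    pairs = [(p, t) for p, t in zip(preds, targets) if t != ignore_index]
    if not pairs:
        return float("nan")  # every pixel was ignored, so accuracy is undefined
    return sum(p == t for p, t in pairs) / len(pairs)
```

Returning NaN when every pixel is ignored makes an all-nodata tile visible instead of silently scoring it as 0 or 1.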

forward(*args, **kwargs)

Forward pass of the model.

Parameters:

x – tensor of data to run through the model

Returns:

output from the model

Return type:

Any

training_step(*args, **kwargs)

Compute and return the training loss.

Parameters:

batch – the output of your DataLoader

Returns:

training loss

Return type:

Tensor

training_epoch_end(outputs)

Logs epoch-level training metrics.

Parameters:

outputs (Any) – list of items returned by training_step

validation_step(*args, **kwargs)

Compute validation loss and log example predictions.

Parameters:
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

validation_epoch_end(outputs)

Logs epoch-level validation metrics.

Parameters:

outputs (Any) – list of items returned by validation_step

test_step(*args, **kwargs)

Compute test loss.

Parameters:

batch – the output of your DataLoader

test_epoch_end(outputs)

Logs epoch-level test metrics.

Parameters:

outputs (Any) – list of items returned by test_step

predict_step(*args, **kwargs)

Compute and return the predictions.

By default, prediction loops over the images in the dataloader and aggregates all predictions into a list. With many images, or with large images, this can cause out-of-memory errors; in that case, override this method with a custom predict_step.

Parameters:

batch – the output of your DataLoader

Returns:

predicted softmax probabilities

Return type:

Tensor
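One way to keep memory bounded in a custom predict_step is to run the model on fixed-size windows and write each window's result to its destination as you go, rather than holding whole-image predictions. A sketch of a hypothetical helper, `tile_windows`, that only computes the window offsets; slicing the tensor and stitching overlapping predictions (for example by averaging) are left out.

```python
from typing import Iterator, List, Tuple


def _window_starts(size: int, tile: int, stride: int) -> List[int]:
    """Start offsets along one axis; the last window is shifted to end at the edge."""
    if size <= tile:
        return [0]
    starts = list(range(0, size - tile + 1, stride))
    if starts[-1] != size - tile:
        starts.append(size - tile)
    return starts


def tile_windows(
    height: int, width: int, tile: int, stride: int
) -> Iterator[Tuple[int, int, int, int]]:
    """Yield (row, col, h, w) windows that together cover a height x width image.

    stride must not exceed tile, otherwise neighbouring windows would leave gaps.
    """
    if not 0 < stride <= tile:
        raise ValueError("stride must satisfy 0 < stride <= tile")
    h, w = min(tile, height), min(tile, width)
    for row in _window_starts(height, tile, stride):
        for col in _window_starts(width, tile, stride):
            yield row, col, h, w
```

Shifting the last window back to the image edge keeps every window the same size, which suits models that expect a fixed input shape, at the cost of some overlap.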

configure_optimizers()

Initialize the optimizer and learning rate scheduler.

Returns:

an “lr dict”, as described in the PyTorch Lightning documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type:

Dict[str, Any]
