torchgeo.trainers¶
TorchGeo trainers.
- class torchgeo.trainers.ClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Bases:
BaseTask
Image classification.
- __init__(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Initialize a new ClassificationTask instance.
- Parameters:
weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
num_classes (int) – Number of prediction classes.
loss (str) – One of ‘ce’, ‘bce’, ‘jaccard’, or ‘focal’.
class_weights (torch.Tensor | None) – Optional rescaling weight given to each class and used with ‘ce’ loss.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to linear probe the classifier head.
Changed in version 0.4: classification_model was renamed to model.
New in version 0.5: The class_weights and freeze_backbone parameters.
Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
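For orientation, a minimal end-to-end sketch follows. It assumes Lightning 2.x and the EuroSATDataModule from torchgeo.datamodules; the data root, batch size, and hyperparameter values are illustrative choices, not defaults of this class.

from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask

# EuroSAT: 13-band Sentinel-2 patches, 10 land-cover classes (illustrative choice)
datamodule = EuroSATDataModule(
    root="data/eurosat", batch_size=64, num_workers=4, download=True
)

task = ClassificationTask(
    model="resnet18",
    weights=True,      # ImageNet initialization, per the weights parameter above
    in_channels=13,
    num_classes=10,
    loss="ce",
    lr=1e-3,
    patience=10,
)

trainer = Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model=task, datamodule=datamodule)
trainer.test(model=task, datamodule=datamodule)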
- configure_losses()[source]¶
Initialize the loss criterion.
- Raises:
ValueError – If loss is invalid.
- configure_metrics()[source]¶
Initialize the performance metrics.
MulticlassAccuracy: The number of true positives divided by the dataset size. Both overall accuracy (OA) using ‘micro’ averaging and average accuracy (AA) using ‘macro’ averaging are reported. Higher values are better.
MulticlassJaccardIndex: Intersection over union (IoU). Uses ‘macro’ averaging. Higher values are better.
MulticlassFBetaScore: F1 score. The harmonic mean of precision and recall. Uses ‘micro’ averaging. Higher values are better.
Note
‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.
‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
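The difference between the two averaging modes is easiest to see on a small imbalanced example. The snippet below uses torchmetrics directly (the same metric family this task configures) with made-up predictions in which the minority class is always misclassified:

import torch
from torchmetrics.classification import MulticlassAccuracy

# 3-class toy problem: classes 0 and 1 are common and always correct,
# class 2 is a rare minority class that is always predicted wrong.
preds = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 0, 0])
target = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])

micro = MulticlassAccuracy(num_classes=3, average="micro")
macro = MulticlassAccuracy(num_classes=3, average="macro")

print(micro(preds, target))  # 0.80 -> overall accuracy, dominated by the majority classes
print(macro(preds, target))  # ~0.67 -> average accuracy, exposes the failing minority class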
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.MultiLabelClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Bases:
ClassificationTask
Multi-label image classification.
- configure_metrics()[source]¶
Initialize the performance metrics.
MultilabelAccuracy: The number of true positives divided by the dataset size. Both overall accuracy (OA) using ‘micro’ averaging and average accuracy (AA) using ‘macro’ averaging are reported. Higher values are better.
MultilabelFBetaScore: F1 score. The harmonic mean of precision and recall. Uses ‘micro’ averaging. Higher values are better.
Note
‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.
‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
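In the multi-label setting each image can carry several labels at once, so targets are multi-hot vectors and per-class probabilities come from a sigmoid rather than a softmax. A small sketch (the channel and class counts are illustrative, loosely modelled on a Sentinel-2/BigEarthNet-style setup):

import torch
from torchgeo.trainers import MultiLabelClassificationTask

task = MultiLabelClassificationTask(
    model="resnet18",
    in_channels=12,   # illustrative: 12 spectral bands
    num_classes=19,   # illustrative: 19 possible labels per image
    loss="bce",       # binary cross-entropy, the usual multi-label choice
)

images = torch.randn(4, 12, 224, 224)
logits = task(images)                 # raw per-class scores, shape (4, 19)
probs = torch.sigmoid(logits)         # independent per-class probabilities
preds = (probs > 0.5).int()           # multi-hot prediction for each image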
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.ObjectDetectionTask(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Bases:
BaseTask
Object detection.
New in version 0.4.
- monitor = 'val_map'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- mode = 'max'¶
Whether the goal is to minimize or maximize the performance metric to monitor.
- __init__(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Initialize a new ObjectDetectionTask instance.
- Parameters:
model (str) – Name of the torchvision model to use. One of ‘faster-rcnn’, ‘fcos’, or ‘retinanet’.
backbone (str) – Name of the torchvision backbone to use. One of ‘resnet18’, ‘resnet34’, ‘resnet50’, ‘resnet101’, ‘resnet152’, ‘resnext50_32x4d’, ‘resnext101_32x8d’, ‘wide_resnet50_2’, or ‘wide_resnet101_2’.
weights (bool | None) – Initial model weights. True for ImageNet weights, False or None for random weights.
in_channels (int) – Number of input channels to model.
num_classes (int) – Number of prediction classes (including the background).
trainable_layers (int) – Number of trainable layers.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to fine-tune the detection head.
Changed in version 0.4: detection_model was renamed to model.
New in version 0.5: The freeze_backbone parameter.
Changed in version 0.5: pretrained, learning_rate, and learning_rate_schedule_patience were renamed to weights, lr, and patience.
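A construction sketch following the parameter guidance above. Note that num_classes includes the background class, and that the monitored metric is ‘val_map’ with mode ‘max’ (see the class attributes), so checkpointing and early stopping should track validation mAP. The single-foreground-class setup and trainer settings are illustrative:

from lightning.pytorch import Trainer
from torchgeo.trainers import ObjectDetectionTask

task = ObjectDetectionTask(
    model="faster-rcnn",
    backbone="resnet50",
    weights=True,        # ImageNet backbone weights
    in_channels=3,
    num_classes=2,       # 1 foreground class + background
    trainable_layers=3,
    lr=1e-3,
)

trainer = Trainer(max_epochs=20, accelerator="auto")
# trainer.fit(model=task, datamodule=...)  # any torchgeo detection datamodule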
- configure_models()[source]¶
Initialize the model.
- Raises:
ValueError – If model or backbone are invalid.
- configure_metrics()[source]¶
Initialize the performance metrics.
MeanAveragePrecision: Mean average precision (mAP) and mean average recall (mAR). Precision is the number of true positives divided by the number of true positives + false positives. Recall is the number of true positives divided by the number of true positives + false negatives. Uses ‘macro’ averaging. Higher values are better.
Note
‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.
‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
- class torchgeo.trainers.PixelwiseRegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Bases:
RegressionTask
LightningModule for pixelwise regression of images.
New in version 0.5.
- class torchgeo.trainers.RegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Bases:
BaseTask
Regression.
- __init__(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Initialize a new RegressionTask instance.
- Parameters:
backbone (str) – Name of the timm or smp backbone to use. Only applicable to PixelwiseRegressionTask.
weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
num_outputs (int) – Number of prediction outputs.
num_filters (int) – Number of filters. Only applicable when model=’fcn’.
loss (str) – One of ‘mse’ or ‘mae’.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to linear probe the regression head. Does not support FCN models.
freeze_decoder (bool) – Freeze the decoder network to linear probe the regression head. Does not support FCN models. Only applicable to PixelwiseRegressionTask.
Changed in version 0.4: Regression model support changed from torchvision.models to timm.
New in version 0.5: The freeze_backbone and freeze_decoder parameters.
Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
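The two regression tasks differ in output shape: RegressionTask predicts num_outputs values per image, while PixelwiseRegressionTask predicts a value per pixel and is the only one that uses the backbone and freeze_decoder parameters. A construction sketch with illustrative settings:

from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask

# Scalar regression: one value per image (e.g. a per-image quantity).
scalar_task = RegressionTask(
    model="resnet18",
    in_channels=3,
    num_outputs=1,
    loss="mse",
    lr=1e-3,
)

# Pixelwise regression: one value per pixel (e.g. a dense height map).
dense_task = PixelwiseRegressionTask(
    model="unet",          # assumed to be among the supported pixelwise models
    backbone="resnet18",
    weights=True,
    in_channels=3,
    num_outputs=1,
    loss="mae",
)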
- configure_losses()[source]¶
Initialize the loss criterion.
- Raises:
ValueError – If loss is invalid.
- configure_metrics()[source]¶
Initialize the performance metrics.
MeanSquaredError: The average of the squared differences between the predicted and actual values (MSE) and its square root (RMSE). Lower values are better.
MeanAbsoluteError: The average of the absolute differences between the predicted and actual values (MAE). Lower values are better.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.SemanticSegmentationTask(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Bases:
BaseTask
Semantic Segmentation.
- __init__(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Initialize a new SemanticSegmentationTask instance.
- Parameters:
weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. FCN model does not support pretrained weights. Pretrained ViT weight enums are not supported yet.
in_channels (int) – Number of input channels to model.
num_classes (int) – Number of prediction classes (including the background).
num_filters (int) – Number of filters. Only applicable when model=’fcn’.
loss (str) – Name of the loss function, currently supports ‘ce’, ‘jaccard’ or ‘focal’ loss.
class_weights (torch.Tensor | None) – Optional rescaling weight given to each class and used with ‘ce’ loss.
ignore_index (int | None) – Optional integer class index to ignore in the loss and metrics.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.
freeze_decoder (bool) – Freeze the decoder network to linear probe the segmentation head.
Changed in version 0.3: ignore_zeros was renamed to ignore_index.
Changed in version 0.4: segmentation_model, encoder_name, and encoder_weights were renamed to model, backbone, and weights.
New in version 0.5: The class_weights, freeze_backbone, and freeze_decoder parameters.
Changed in version 0.5: The weights parameter now supports WeightEnums and checkpoint paths. learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
Changed in version 0.6: The ignore_index parameter now works for jaccard loss.
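A construction sketch tying the parameters above together; the 4-band input, 7-class label set, class weights, and nodata index are illustrative choices:

import torch
from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=True,             # pretrained encoder weights
    in_channels=4,            # e.g. RGB + NIR
    num_classes=7,            # including the background class
    loss="ce",
    class_weights=torch.tensor([0.5, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0]),  # used with 'ce'
    ignore_index=0,           # e.g. exclude a nodata class from loss and metrics
    lr=1e-3,
    freeze_backbone=False,
)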
- configure_models()[source]¶
Initialize the model.
- Raises:
ValueError – If model is invalid.
- configure_losses()[source]¶
Initialize the loss criterion.
- Raises:
ValueError – If loss is invalid.
- configure_metrics()[source]¶
Initialize the performance metrics.
MulticlassAccuracy: Overall accuracy (OA) using ‘micro’ averaging. The number of true positives divided by the dataset size. Higher values are better.
MulticlassJaccardIndex: Intersection over union (IoU). Uses ‘micro’ averaging. Higher values are better.
Note
‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.
‘Macro’ averaging, not used here, gives equal weight to each class, useful for balanced performance assessment across imbalanced classes.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.BYOLTask(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]¶
Bases:
BaseTask
BYOL: Bootstrap Your Own Latent.
Reference implementation:
If you use this trainer in your research, please cite the following paper:
https://arxiv.org/abs/2006.07733
- monitor = 'train_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- __init__(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]¶
Initialize a new BYOLTask instance.
- Parameters:
weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
Changed in version 0.4: backbone_name was renamed to backbone. Changed backbone support from torchvision.models to timm.
Changed in version 0.5: backbone, learning_rate, and learning_rate_schedule_patience were renamed to model, lr, and patience.
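BYOL is self-supervised, so it only needs image batches and monitors ‘train_loss’ (see the class attribute above) rather than a validation metric. A minimal pretraining sketch, with illustrative settings and the datamodule left as a placeholder:

from lightning.pytorch import Trainer
from torchgeo.trainers import BYOLTask

task = BYOLTask(
    model="resnet50",
    in_channels=13,   # illustrative: 13-band multispectral imagery
    lr=1e-3,
    patience=10,
)

trainer = Trainer(max_epochs=100, accelerator="auto")
# trainer.fit(model=task, datamodule=...)  # any datamodule yielding image batches;
#                                          # the pretrained backbone can later seed a
#                                          # supervised task via its weights parameter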
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- Parameters:
batch (Any) – The output of your DataLoader.
batch_idx (int) – Integer displaying index of this batch.
dataloader_idx (int) – Index of the current dataloader.
- Returns:
The loss tensor.
- Raises:
AssertionError – If channel dimensions are incorrect.
- Return type:
Tensor
- class torchgeo.trainers.MoCoTask(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]¶
Bases:
BaseTask
MoCo: Momentum Contrast.
Reference implementations:
If you use this trainer in your research, please cite the following papers:
https://arxiv.org/abs/1911.05722
https://arxiv.org/abs/2003.04297
https://arxiv.org/abs/2104.02057
New in version 0.5.
- monitor = 'train_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- __init__(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]¶
Initialize a new MoCoTask instance.
- Parameters:
weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
version (int) – Version of MoCo, 1–3.
layers (int) – Number of layers in projection head (not used in v1, 2 for v1/2, 3 for v3).
hidden_dim (int) – Number of hidden dimensions in projection head (not used in v1, 2048 for v2, 4096 for v3).
output_dim (int) – Number of output dimensions in projection head (not used in v1, 128 for v2, 256 for v3).
lr (float) – Learning rate (0.03 x batch_size / 256 for v1/2, 0.6 x batch_size / 256 for v3).
weight_decay (float) – Weight decay coefficient (1e-4 for v1/2, 1e-6 for v3).
momentum (float) – Momentum of SGD solver (v1/2 only).
schedule (Sequence[int]) – Epochs at which to drop lr by 10x (v1/2 only).
temperature (float) – Temperature used in InfoNCE loss (0.07 for v1/2, 1 for v3).
memory_bank_size (int) – Size of memory bank (65536 for v1/2, 0 for v3).
moco_momentum (float) – MoCo momentum for updating the key encoder (0.999 for v1/2, 0.99 for v3).
gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).
size (int) – Size of patch to crop.
grayscale_weights (torch.Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.
augmentation1 (torch.nn.modules.module.Module | None) – Data augmentation for 1st branch. Defaults to MoCo augmentation.
augmentation2 (torch.nn.modules.module.Module | None) – Data augmentation for 2nd branch. Defaults to MoCo augmentation.
- Raises:
AssertionError – If an invalid version of MoCo is requested.
- Warns:
UserWarning – If hyperparameters do not match MoCo version requested.
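The per-version recommendations above can be translated into concrete configurations. The sketch below shows a MoCo v2 and a MoCo v3 setup; the batch size is an assumption used only to follow the learning-rate recipe, and both use the default ResNet-50 backbone:

from torchgeo.trainers import MoCoTask

batch_size = 256  # assumed; only used to apply the lr recipe above

# MoCo v2: 2-layer projection head, large memory bank, step lr schedule.
moco_v2 = MoCoTask(
    model="resnet50",
    in_channels=3,
    version=2,
    layers=2,
    hidden_dim=2048,
    output_dim=128,
    lr=0.03 * batch_size / 256,
    weight_decay=1e-4,
    momentum=0.9,
    schedule=[120, 160],
    temperature=0.07,
    memory_bank_size=65536,
    moco_momentum=0.999,
)

# MoCo v3: 3-layer projection head, no memory bank, larger lr, temperature 1.
moco_v3 = MoCoTask(
    model="resnet50",
    in_channels=3,
    version=3,
    layers=3,
    hidden_dim=4096,
    output_dim=256,
    lr=0.6 * batch_size / 256,
    weight_decay=1e-6,
    temperature=1,
    memory_bank_size=0,
    moco_momentum=0.99,
)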
- configure_optimizers()[source]¶
Initialize the optimizer and learning rate scheduler.
- Returns:
Optimizer and learning rate scheduler.
- Return type:
OptimizerLRSchedulerConfig
- forward(x)[source]¶
Forward pass of the model.
- Parameters:
x (Tensor) – Mini-batch of images.
- Returns:
Output of the model and backbone.
- Return type:
tuple[Tensor, Tensor]
- class torchgeo.trainers.SimCLRTask(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, momentum=0.9, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]¶
Bases:
BaseTask
SimCLR: a simple framework for contrastive learning of visual representations.
Reference implementation:
If you use this trainer in your research, please cite the following papers:
https://arxiv.org/abs/2002.05709
https://arxiv.org/abs/2006.10029
New in version 0.5.
- monitor = 'train_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- __init__(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, momentum=0.9, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]¶
Initialize a new SimCLRTask instance.
New in version 0.6: The momentum parameter.
- Parameters:
weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
version (int) – Version of SimCLR, 1–2.
layers (int) – Number of layers in projection head (2 for v1, 3+ for v2).
hidden_dim (int | None) – Number of hidden dimensions in projection head (defaults to output dimension of model).
output_dim (int | None) – Number of output dimensions in projection head (defaults to output dimension of model).
lr (float) – Learning rate (0.3 x batch_size / 256 is recommended).
momentum (float) – Momentum factor.
weight_decay (float) – Weight decay coefficient (1e-6 for v1, 1e-4 for v2).
temperature (float) – Temperature used in NT-Xent loss.
memory_bank_size (int) – Size of memory bank (0 for v1, 64K for v2).
gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).
size (int) – Size of patch to crop.
grayscale_weights (torch.Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.
augmentations (torch.nn.modules.module.Module | None) – Data augmentation. Defaults to SimCLR augmentation.
- Raises:
AssertionError – If an invalid version of SimCLR is requested.
- Warns:
UserWarning – If hyperparameters do not match SimCLR version requested.
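Similarly, the version guidance above maps to the following illustrative configurations (batch size assumed only for the learning-rate recipe):

from torchgeo.trainers import SimCLRTask

batch_size = 256  # assumed; only used to apply the lr recipe above

# SimCLR v1: 2-layer projection head, no memory bank.
simclr_v1 = SimCLRTask(
    model="resnet50",
    in_channels=3,
    version=1,
    layers=2,
    lr=0.3 * batch_size / 256,
    weight_decay=1e-6,
    memory_bank_size=0,
)

# SimCLR v2: deeper projection head plus a memory bank of negatives.
simclr_v2 = SimCLRTask(
    model="resnet50",
    in_channels=3,
    version=2,
    layers=3,
    lr=0.3 * batch_size / 256,
    weight_decay=1e-4,
    memory_bank_size=64000,
)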
- forward(x)[source]¶
Forward pass of the model.
- Parameters:
x (Tensor) – Mini-batch of images.
- Returns:
Output of the model and backbone.
- Return type:
tuple[Tensor, Tensor]
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- Parameters:
batch (Any) – The output of your DataLoader.
batch_idx (int) – Integer displaying index of this batch.
dataloader_idx (int) – Index of the current dataloader.
- Returns:
The loss tensor.
- Raises:
AssertionError – If channel dimensions are incorrect.
- Return type:
Tensor
- class torchgeo.trainers.BaseTask(ignore=None)[source]¶
Bases:
LightningModule, ABC
Abstract base class for all TorchGeo trainers.
New in version 0.5.
- monitor = 'val_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- mode = 'min'¶
Whether the goal is to minimize or maximize the performance metric to monitor.
- __init__(ignore=None)[source]¶
Initialize a new BaseTask instance.
- Parameters:
ignore (collections.abc.Sequence[str] | str | None) – Arguments to skip when saving hyperparameters.
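Every task exposes monitor and mode class attributes, so Lightning callbacks can be configured generically without hard-coding metric names. A sketch, using an arbitrary task and illustrative callback settings:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model="unet", backbone="resnet18", in_channels=3, num_classes=5
)

# task.monitor and task.mode tell the callbacks which metric to track and how.
checkpoint = ModelCheckpoint(monitor=task.monitor, mode=task.mode, save_top_k=1)
early_stop = EarlyStopping(monitor=task.monitor, mode=task.mode, patience=20)

trainer = Trainer(callbacks=[checkpoint, early_stop], max_epochs=100)
# trainer.fit(model=task, datamodule=...)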
- class torchgeo.trainers.IOBenchTask(ignore=None)[source]¶
Bases:
BaseTask
I/O benchmarking.
New in version 0.6.
- configure_optimizers()[source]¶
Initialize the optimizer.
- Returns:
Optimizer.
- Return type:
OptimizerLRSchedulerConfig