torchgeo.trainers¶
TorchGeo trainers.
- class torchgeo.trainers.ClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Bases:
BaseTask
Image classification.
- __init__(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Initialize a new ClassificationTask instance.
- Parameters:
weights (Optional[Union[WeightsEnum, str, bool]]) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
num_classes (int) – Number of prediction classes.
loss (str) – One of ‘ce’, ‘bce’, ‘jaccard’, or ‘focal’.
class_weights (Optional[Tensor]) – Optional rescaling weight given to each class and used with ‘ce’ loss.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to linear probe the classifier head.
Changed in version 0.4: classification_model was renamed to model.
New in version 0.5: The class_weights and freeze_backbone parameters.
Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
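The renames listed in the version notes above can be applied mechanically when updating older configuration files. The following is a minimal sketch in plain Python; `RENAMED_KWARGS` and `migrate_kwargs` are hypothetical names used for illustration and are not part of TorchGeo.

```python
# Hypothetical helper, not part of TorchGeo: translate legacy keyword
# arguments for ClassificationTask into the names documented above.
RENAMED_KWARGS = {
    "classification_model": "model",  # renamed in version 0.4
    "learning_rate": "lr",  # renamed in version 0.5
    "learning_rate_schedule_patience": "patience",  # renamed in version 0.5
}


def migrate_kwargs(old_kwargs):
    """Return a copy of ``old_kwargs`` with legacy names replaced."""
    new_kwargs = {}
    for name, value in old_kwargs.items():
        new_name = RENAMED_KWARGS.get(name, name)
        if new_name in new_kwargs:
            # The same setting was supplied under both its old and new name.
            raise ValueError(f"'{new_name}' was given more than once")
        new_kwargs[new_name] = value
    return new_kwargs


# Illustrative legacy configuration; the values are arbitrary.
legacy = {
    "classification_model": "resnet18",
    "learning_rate": 1e-3,
    "learning_rate_schedule_patience": 5,
    "num_classes": 10,
}
migrated = migrate_kwargs(legacy)
print(migrated)
```

The resulting dictionary uses only parameter names from the current signature, so it could be unpacked into the constructor as `ClassificationTask(**migrated)`.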
- configure_losses()[source]¶
Initialize the loss criterion.
- Raises:
ValueError – If loss is invalid.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.MultiLabelClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Bases:
ClassificationTask
Multi-label image classification.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.ObjectDetectionTask(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Bases:
BaseTask
Object detection.
New in version 0.4.
- monitor = 'val_map'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- mode = 'max'¶
Whether the goal is to minimize or maximize the performance metric to monitor.
- __init__(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]¶
Initialize a new ObjectDetectionTask instance.
- Parameters:
model (str) – Name of the torchvision model to use. One of ‘faster-rcnn’, ‘fcos’, or ‘retinanet’.
backbone (str) – Name of the torchvision backbone to use. One of ‘resnet18’, ‘resnet34’, ‘resnet50’, ‘resnet101’, ‘resnet152’, ‘resnext50_32x4d’, ‘resnext101_32x8d’, ‘wide_resnet50_2’, or ‘wide_resnet101_2’.
weights (Optional[bool]) – Initial model weights. True for ImageNet weights, False or None for random weights.
in_channels (int) – Number of input channels to model.
num_classes (int) – Number of prediction classes.
trainable_layers (int) – Number of trainable layers.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to fine-tune the detection head.
Changed in version 0.4: detection_model was renamed to model.
New in version 0.5: The freeze_backbone parameter.
Changed in version 0.5: pretrained, learning_rate, and learning_rate_schedule_patience were renamed to weights, lr, and patience.
- configure_models()[source]¶
Initialize the model.
- Raises:
ValueError – If model or backbone are invalid.
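Because configure_models() rejects unknown names, it can be convenient to validate a configuration before building the task. The sketch below checks `model` and `backbone` against the values listed in the parameter documentation above. `check_detection_config` is a hypothetical helper; the order of checks and the error messages in TorchGeo itself may differ.

```python
# Names copied from the ObjectDetectionTask parameter documentation.
VALID_MODELS = ("faster-rcnn", "fcos", "retinanet")
VALID_BACKBONES = (
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "wide_resnet50_2",
    "wide_resnet101_2",
)


def check_detection_config(model, backbone):
    """Raise ValueError if either name is not in the documented lists.

    Hypothetical pre-check, not a TorchGeo function.
    """
    if model not in VALID_MODELS:
        raise ValueError(f"Unknown model '{model}', expected one of {VALID_MODELS}")
    if backbone not in VALID_BACKBONES:
        raise ValueError(
            f"Unknown backbone '{backbone}', expected one of {VALID_BACKBONES}"
        )


# The defaults from the class signature pass the check.
check_detection_config("faster-rcnn", "resnet50")
```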
- class torchgeo.trainers.PixelwiseRegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Bases:
RegressionTask
LightningModule for pixelwise regression of images.
New in version 0.5.
- class torchgeo.trainers.RegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Bases:
BaseTask
Regression.
- __init__(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Initialize a new RegressionTask instance.
- Parameters:
backbone (str) – Name of the timm or smp backbone to use. Only applicable to PixelwiseRegressionTask.
weights (Optional[Union[WeightsEnum, str, bool]]) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
num_outputs (int) – Number of prediction outputs.
num_filters (int) – Number of filters. Only applicable when model=’fcn’.
loss (str) – One of ‘mse’ or ‘mae’.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to linear probe the regression head. Does not support FCN models.
freeze_decoder (bool) – Freeze the decoder network to linear probe the regression head. Does not support FCN models. Only applicable to PixelwiseRegressionTask.
Changed in version 0.4: Regression model support changed from torchvision.models to timm.
New in version 0.5: The freeze_backbone and freeze_decoder parameters.
Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
- configure_losses()[source]¶
Initialize the loss criterion.
- Raises:
ValueError – If loss is invalid.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.SemanticSegmentationTask(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Bases:
BaseTask
Semantic segmentation.
- __init__(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]¶
Initialize a new SemanticSegmentationTask instance.
- Parameters:
weights (Optional[Union[WeightsEnum, str, bool]]) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. The FCN model does not support pretrained weights. Pretrained ViT weight enums are not supported yet.
in_channels (int) – Number of input channels to model.
num_classes (int) – Number of prediction classes.
num_filters (int) – Number of filters. Only applicable when model=’fcn’.
loss (str) – Name of the loss function. Currently supports ‘ce’, ‘jaccard’, or ‘focal’ loss.
class_weights (Optional[Tensor]) – Optional rescaling weight given to each class and used with ‘ce’ loss.
ignore_index (Optional[int]) – Optional integer class index to ignore in the loss and metrics.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.
freeze_decoder (bool) – Freeze the decoder network to linear probe the segmentation head.
- Warns:
UserWarning – When loss=’jaccard’ and ignore_index is specified.
Changed in version 0.3: ignore_zeros was renamed to ignore_index.
Changed in version 0.4: segmentation_model, encoder_name, and encoder_weights were renamed to model, backbone, and weights.
Changed in version 0.5: The weights parameter now supports WeightsEnum values and checkpoint paths. learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
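As a rough illustration of how the parameters above fit together, the sketch below collects them in a dictionary and builds the task in a small function. Every keyword comes from the signature documented above; the values are arbitrary examples, and `build_segmentation_task` is a hypothetical wrapper. TorchGeo is imported inside the function so that the configuration can be inspected without the library installed.

```python
# Illustrative configuration; every key is a documented __init__ parameter.
segmentation_kwargs = {
    "model": "unet",
    "backbone": "resnet18",
    "weights": None,  # random initialization
    "in_channels": 4,  # e.g. a four-band image (assumed for illustration)
    "num_classes": 5,
    "loss": "ce",  # 'jaccard' combined with ignore_index triggers a UserWarning
    "ignore_index": 0,
    "lr": 1e-3,
    "patience": 5,
}


def build_segmentation_task(**kwargs):
    """Construct a SemanticSegmentationTask from keyword arguments.

    Hypothetical wrapper: the import is deferred so that this module can be
    loaded even when TorchGeo is not installed.
    """
    from torchgeo.trainers import SemanticSegmentationTask

    return SemanticSegmentationTask(**kwargs)


# Usage, assuming TorchGeo and its dependencies are installed:
# task = build_segmentation_task(**segmentation_kwargs)
```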
- configure_losses()[source]¶
Initialize the loss criterion.
- Raises:
ValueError – If loss is invalid.
- configure_models()[source]¶
Initialize the model.
- Raises:
ValueError – If model is invalid.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the validation loss and additional metrics.
- test_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the test loss and additional metrics.
- class torchgeo.trainers.BYOLTask(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]¶
Bases:
BaseTask
BYOL: Bootstrap Your Own Latent.
Reference implementation:
If you use this trainer in your research, please cite the following paper:
- monitor = 'train_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- __init__(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]¶
Initialize a new BYOLTask instance.
- Parameters:
weights (Optional[Union[WeightsEnum, str, bool]]) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
lr (float) – Learning rate for optimizer.
patience (int) – Patience for learning rate scheduler.
Changed in version 0.4: backbone_name was renamed to backbone. Changed backbone support from torchvision.models to timm.
Changed in version 0.5: backbone, learning_rate, and learning_rate_schedule_patience were renamed to model, lr, and patience.
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- Parameters:
- Returns:
The loss tensor.
- Raises:
AssertionError – If channel dimensions are incorrect.
- Return type:
- class torchgeo.trainers.MoCoTask(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]¶
Bases:
BaseTask
MoCo: Momentum Contrast.
Reference implementations:
If you use this trainer in your research, please cite the following papers:
New in version 0.5.
- monitor = 'train_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- __init__(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]¶
Initialize a new MoCoTask instance.
- Parameters:
weights (Optional[Union[WeightsEnum, str, bool]]) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
version (int) – Version of MoCo, 1–3.
layers (int) – Number of layers in projection head (not used in v1, 2 for v2, 3 for v3).
hidden_dim (int) – Number of hidden dimensions in projection head (not used in v1, 2048 for v2, 4096 for v3).
output_dim (int) – Number of output dimensions in projection head (not used in v1, 128 for v2, 256 for v3).
lr (float) – Learning rate (0.03 x batch_size / 256 for v1/2, 0.6 x batch_size / 256 for v3).
weight_decay (float) – Weight decay coefficient (1e-4 for v1/2, 1e-6 for v3).
momentum (float) – Momentum of SGD solver (v1/2 only).
schedule (Sequence[int]) – Epochs at which to drop lr by 10x (v1/2 only).
temperature (float) – Temperature used in InfoNCE loss (0.07 for v1/2, 1 for v3).
memory_bank_size (int) – Size of memory bank (65536 for v1/2, 0 for v3).
moco_momentum (float) – MoCo momentum of updating key encoder (0.999 for v1/2, 0.99 for v3).
gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).
size (int) – Size of patch to crop.
grayscale_weights (Optional[Tensor]) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.
augmentation1 (Optional[Module]) – Data augmentation for 1st branch. Defaults to MoCo augmentation.
augmentation2 (Optional[Module]) – Data augmentation for 2nd branch. Defaults to MoCo augmentation.
- Raises:
AssertionError – If an invalid version of MoCo is requested.
- Warns:
UserWarning – If hyperparameters do not match MoCo version requested.
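The per-version values quoted in the parameter descriptions above can be gathered into a lookup table, with the learning rate scaled by batch size as described for `lr`. This is a minimal sketch: `MOCO_RECOMMENDED` and `moco_kwargs` are hypothetical names, and the table only restates the numbers given in this documentation.

```python
# Recommended values per MoCo version, as listed in the parameter docs.
# "lr_base" is the factor multiplied by batch_size / 256.
MOCO_RECOMMENDED = {
    1: {
        "lr_base": 0.03,
        "weight_decay": 1e-4,
        "temperature": 0.07,
        "memory_bank_size": 65536,
        "moco_momentum": 0.999,
    },
    2: {
        "lr_base": 0.03,
        "weight_decay": 1e-4,
        "temperature": 0.07,
        "memory_bank_size": 65536,
        "moco_momentum": 0.999,
        "layers": 2,
        "hidden_dim": 2048,
        "output_dim": 128,
    },
    3: {
        "lr_base": 0.6,
        "weight_decay": 1e-6,
        "temperature": 1,
        "memory_bank_size": 0,
        "moco_momentum": 0.99,
        "layers": 3,
        "hidden_dim": 4096,
        "output_dim": 256,
    },
}


def moco_kwargs(version, batch_size):
    """Return keyword arguments matching the documented recommendations."""
    if version not in MOCO_RECOMMENDED:
        raise ValueError(f"MoCo version must be 1, 2, or 3, got {version}")
    kwargs = dict(MOCO_RECOMMENDED[version])
    # Linear scaling rule from the lr description: base x batch_size / 256.
    kwargs["lr"] = kwargs.pop("lr_base") * batch_size / 256
    kwargs["version"] = version
    return kwargs


v3 = moco_kwargs(version=3, batch_size=4096)
print(v3["lr"])
```

With a batch size of 4096, the v3 rule gives 0.6 x 4096 / 256 = 9.6, which matches the default `lr` in the signature. The projection head entries are omitted for v1 because the documentation states they are not used there.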
- configure_optimizers()[source]¶
Initialize the optimizer and learning rate scheduler.
- Returns:
Optimizer and learning rate scheduler.
- Return type:
OptimizerLRSchedulerConfig
- forward(x)[source]¶
Forward pass of the model.
- Parameters:
x (Tensor) – Mini-batch of images.
- Returns:
Output of the model and backbone.
- Return type:
- class torchgeo.trainers.SimCLRTask(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]¶
Bases:
BaseTask
SimCLR: a simple framework for contrastive learning of visual representations.
Reference implementation:
If you use this trainer in your research, please cite the following papers:
New in version 0.5.
- monitor = 'train_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- __init__(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]¶
Initialize a new SimCLRTask instance.
- Parameters:
weights (Optional[Union[WeightsEnum, str, bool]]) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
in_channels (int) – Number of input channels to model.
version (int) – Version of SimCLR, 1–2.
layers (int) – Number of layers in projection head (2 for v1, 3+ for v2).
hidden_dim (Optional[int]) – Number of hidden dimensions in projection head (defaults to output dimension of model).
output_dim (Optional[int]) – Number of output dimensions in projection head (defaults to output dimension of model).
lr (float) – Learning rate (0.3 x batch_size / 256 is recommended).
weight_decay (float) – Weight decay coefficient (1e-6 for v1, 1e-4 for v2).
temperature (float) – Temperature used in NT-Xent loss.
memory_bank_size (int) – Size of memory bank (0 for v1, 64K for v2).
gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).
size (int) – Size of patch to crop.
grayscale_weights (Optional[Tensor]) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.
augmentations (Optional[Module]) – Data augmentation. Defaults to SimCLR augmentation.
- Raises:
AssertionError – If an invalid version of SimCLR is requested.
- Warns:
UserWarning – If hyperparameters do not match SimCLR version requested.
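To see which settings would disagree with the requested SimCLR version, the recommendations in the parameter descriptions can be compared against a configuration ahead of time. The sketch below is an assumption about how such a check might look, not TorchGeo's implementation. `SIMCLR_RECOMMENDED`, `simclr_lr`, and `find_mismatches` are hypothetical names, and "64K" is read as 64000, the default in the signature.

```python
# Recommended values per SimCLR version, as listed in the parameter docs.
SIMCLR_RECOMMENDED = {
    1: {"weight_decay": 1e-6, "memory_bank_size": 0},
    2: {"weight_decay": 1e-4, "memory_bank_size": 64000},
}


def simclr_lr(batch_size):
    """Learning rate recommended in the docs: 0.3 x batch_size / 256."""
    return 0.3 * batch_size / 256


def find_mismatches(version, **hparams):
    """List the hyperparameter names that differ from the recommendation.

    Hypothetical pre-check; parameters that are not supplied are skipped.
    """
    if version not in SIMCLR_RECOMMENDED:
        raise ValueError(f"SimCLR version must be 1 or 2, got {version}")
    recommended = SIMCLR_RECOMMENDED[version]
    return [
        name
        for name, expected in recommended.items()
        if name in hparams and hparams[name] != expected
    ]


# The signature defaults are consistent with version 2 ...
print(find_mismatches(2, weight_decay=1e-4, memory_bank_size=64000))
# ... but both values differ from the version 1 recommendations.
print(find_mismatches(1, weight_decay=1e-4, memory_bank_size=64000))
```

For a batch size of 4096, `simclr_lr` gives 0.3 x 4096 / 256 = 4.8, which matches the default `lr` in the signature.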
- forward(x)[source]¶
Forward pass of the model.
- Parameters:
x (Tensor) – Mini-batch of images.
- Returns:
Output of the model and backbone.
- Return type:
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Compute the training loss and additional metrics.
- Parameters:
- Returns:
The loss tensor.
- Raises:
AssertionError – If channel dimensions are incorrect.
- Return type:
- class torchgeo.trainers.BaseTask(ignore=None)[source]¶
Bases:
LightningModule, ABC
Abstract base class for all TorchGeo trainers.
New in version 0.5.
- monitor = 'val_loss'¶
Performance metric to monitor in learning rate scheduler and callbacks.
- mode = 'min'¶
Whether the goal is to minimize or maximize the performance metric to monitor.
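The monitor and mode attributes together describe which logged metric to watch and which direction counts as better. The sketch below shows only that comparison logic in plain Python. `is_improvement` is a hypothetical function, and the table restates the attribute values documented on this page for BaseTask and ObjectDetectionTask.

```python
# (monitor, mode) pairs as documented on this page.
TASK_MONITORS = {
    "BaseTask": ("val_loss", "min"),
    "ObjectDetectionTask": ("val_map", "max"),
}


def is_improvement(current, best, mode):
    """Return True if ``current`` is better than ``best`` under ``mode``.

    Hypothetical helper illustrating the meaning of 'min' and 'max'.
    """
    if mode == "min":
        return current < best
    if mode == "max":
        return current > best
    raise ValueError(f"mode must be 'min' or 'max', got '{mode}'")


# A lower validation loss is an improvement for the base class ...
metric, mode = TASK_MONITORS["BaseTask"]
print(metric, is_improvement(current=0.2, best=0.3, mode=mode))

# ... while a lower mean average precision is not for object detection.
metric, mode = TASK_MONITORS["ObjectDetectionTask"]
print(metric, is_improvement(current=0.2, best=0.3, mode=mode))
```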