
torchgeo.trainers

TorchGeo trainers.

class torchgeo.trainers.ClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]

Bases: BaseTask

Image classification.

__init__(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]

Initialize a new ClassificationTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes.

  • loss (str) – One of ‘ce’, ‘bce’, ‘jaccard’, or ‘focal’.

  • class_weights (torch.Tensor | None) – Optional rescaling weight given to each class and used with ‘ce’ loss.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to linear probe the classifier head.

Changed in version 0.4: classification_model was renamed to model.

New in version 0.5: The class_weights and freeze_backbone parameters.

Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
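
Example (a minimal usage sketch; the lightning.pytorch import path, the EuroSATDataModule datamodule, and the hyperparameter values are illustrative assumptions, not requirements):

    from lightning.pytorch import Trainer

    from torchgeo.datamodules import EuroSATDataModule
    from torchgeo.trainers import ClassificationTask

    # Batches are dicts with "image" (B x C x H x W) and "label" (B) keys.
    datamodule = EuroSATDataModule(batch_size=64, num_workers=4, root="data/eurosat", download=True)
    task = ClassificationTask(
        model="resnet18",
        weights=True,      # ImageNet-pretrained backbone
        in_channels=13,    # EuroSAT has 13 spectral bands
        num_classes=10,
        loss="ce",
        lr=1e-3,
        patience=5,
    )
    trainer = Trainer(max_epochs=10, accelerator="auto")
    trainer.fit(task, datamodule=datamodule)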

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MulticlassAccuracy: The number of true positives divided by the dataset size. Both overall accuracy (OA) using ‘micro’ averaging and average accuracy (AA) using ‘macro’ averaging are reported. Higher values are better.

  • MulticlassJaccardIndex: Intersection over union (IoU). Uses ‘macro’ averaging. Higher values are better.

  • MulticlassFBetaScore: F1 score. The harmonic mean of precision and recall. Uses ‘micro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted class probabilities.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.MultiLabelClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]

Bases: ClassificationTask

Multi-label image classification.

configure_metrics()[source]

Initialize the performance metrics.

  • MultilabelAccuracy: The number of true positives divided by the dataset size. Both overall accuracy (OA) using ‘micro’ averaging and average accuracy (AA) using ‘macro’ averaging are reported. Higher values are better.

  • MultilabelFBetaScore: F1 score. The harmonic mean of precision and recall. Uses ‘micro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
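
Example (a minimal sketch; multi-hot labels are assumed, and the channel and class counts are illustrative):

    from torchgeo.trainers import MultiLabelClassificationTask

    # Targets are multi-hot vectors of shape (B, num_classes), so 'bce' is the usual loss choice.
    task = MultiLabelClassificationTask(
        model="resnet18",
        in_channels=14,    # e.g. combined Sentinel-1 + Sentinel-2 bands; adjust to your data
        num_classes=19,
        loss="bce",
        lr=1e-3,
    )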

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted class probabilities.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.ObjectDetectionTask(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]

Bases: BaseTask

Object detection.

New in version 0.4.

monitor = 'val_map'

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'max'

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]

Initialize a new ObjectDetectionTask instance.

Parameters:
  • model (str) – Name of the torchvision model to use. One of ‘faster-rcnn’, ‘fcos’, or ‘retinanet’.

  • backbone (str) – Name of the torchvision backbone to use. One of ‘resnet18’, ‘resnet34’, ‘resnet50’, ‘resnet101’, ‘resnet152’, ‘resnext50_32x4d’, ‘resnext101_32x8d’, ‘wide_resnet50_2’, or ‘wide_resnet101_2’.

  • weights (bool | None) – Initial model weights. True for ImageNet weights, False or None for random weights.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes (including the background).

  • trainable_layers (int) – Number of trainable layers.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the detection head.

Changed in version 0.4: detection_model was renamed to model.

New in version 0.5: The freeze_backbone parameter.

Changed in version 0.5: pretrained, learning_rate, and learning_rate_schedule_patience were renamed to weights, lr, and patience.
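
Example (a minimal sketch; the batch format and class count are illustrative, and the background class counts toward num_classes):

    from torchgeo.trainers import ObjectDetectionTask

    # Detection batches pair each image with ground-truth boxes and class labels.
    task = ObjectDetectionTask(
        model="faster-rcnn",
        backbone="resnet50",
        weights=True,          # ImageNet-pretrained backbone
        in_channels=3,
        num_classes=2,         # one foreground class + background
        trainable_layers=3,
        lr=1e-3,
        patience=10,
    )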

configure_models()[source]

Initialize the model.

Raises:

ValueError – If model or backbone are invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MeanAveragePrecision: Mean average precision (mAP) and mean average recall (mAR). Precision is the number of true positives divided by the number of true positives + false positives. Recall is the number of true positives divided by the number of true positives + false negatives. Uses ‘macro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted bounding boxes.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted bounding boxes, labels, and scores.

Return type:

list[dict[str, torch.Tensor]]

class torchgeo.trainers.PixelwiseRegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Bases: RegressionTask

LightningModule for pixelwise regression of images.

New in version 0.5.

configure_models()[source]

Initialize the model.
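
Example (a minimal sketch; the parameters are documented under RegressionTask below, an smp-style model name such as ‘unet’ is assumed, and the values shown are illustrative):

    from torchgeo.trainers import PixelwiseRegressionTask

    # Dense regression: the model predicts one continuous value per pixel.
    task = PixelwiseRegressionTask(
        model="unet",
        backbone="resnet18",
        weights=True,
        in_channels=4,
        num_outputs=1,
        loss="mse",
        lr=1e-3,
    )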

class torchgeo.trainers.RegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Bases: BaseTask

Regression.

__init__(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Initialize a new RegressionTask instance.

Parameters:
  • model (str) – Name of the timm or smp model to use.

  • backbone (str) – Name of the timm or smp backbone to use. Only applicable to PixelwiseRegressionTask.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • num_outputs (int) – Number of prediction outputs.

  • num_filters (int) – Number of filters. Only applicable when model=’fcn’.

  • loss (str) – One of ‘mse’ or ‘mae’.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to linear probe the regression head. Does not support FCN models.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the regression head. Does not support FCN models. Only applicable to PixelwiseRegressionTask.

Changed in version 0.4: Regression model support was changed from torchvision.models to timm.

New in version 0.5: The freeze_backbone and freeze_decoder parameters.

Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.
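
Example (a minimal sketch for scalar, image-level regression; the hyperparameter values are illustrative):

    from torchgeo.trainers import RegressionTask

    # Image-level regression: one scalar prediction per image (e.g. a count or intensity).
    task = RegressionTask(
        model="resnet18",    # timm backbone
        weights=True,
        in_channels=3,
        num_outputs=1,
        loss="mse",
        lr=1e-3,
        patience=5,
    )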

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MeanSquaredError: The average of the squared differences between the predicted and actual values (MSE) and its square root (RMSE). Lower values are better.

  • MeanAbsoluteError: The average of the absolute differences between the predicted and actual values (MAE). Lower values are better.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted regression values.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted regression values.

Return type:

Tensor

class torchgeo.trainers.SemanticSegmentationTask(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Bases: BaseTask

Semantic Segmentation.

__init__(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Initialize a new SemanticSegmentationTask instance.

Parameters:
  • model (str) – Name of the smp model to use.

  • backbone (str) – Name of the timm or smp backbone to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. FCN model does not support pretrained weights. Pretrained ViT weight enums are not supported yet.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes (including the background).

  • num_filters (int) – Number of filters. Only applicable when model=’fcn’.

  • loss (str) – Name of the loss function. One of ‘ce’, ‘jaccard’, or ‘focal’.

  • class_weights (torch.Tensor | None) – Optional rescaling weight given to each class and used with ‘ce’ loss.

  • ignore_index (int | None) – Optional integer class index to ignore in the loss and metrics.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the segmentation head.

Changed in version 0.3: ignore_zeros was renamed to ignore_index.

Changed in version 0.4: segmentation_model, encoder_name, and encoder_weights were renamed to model, backbone, and weights.

New in version 0.5: The class_weights, freeze_backbone, and freeze_decoder parameters.

Changed in version 0.5: The weights parameter now supports WeightEnums and checkpoint paths. learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

Changed in version 0.6: The ignore_index parameter now works for jaccard loss.
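
Example (a minimal sketch; the channel count, class count, and ignore_index are illustrative and should match your dataset):

    from torchgeo.trainers import SemanticSegmentationTask

    # Batches provide "image" inputs and per-pixel "mask" targets.
    task = SemanticSegmentationTask(
        model="unet",
        backbone="resnet50",
        weights=True,        # ImageNet weights for the encoder
        in_channels=4,       # e.g. RGB + NIR imagery
        num_classes=7,
        loss="ce",
        ignore_index=0,      # skip a nodata/background index in the loss and metrics
        lr=1e-3,
    )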

configure_models()[source]

Initialize the model.

Raises:

ValueError – If model is invalid.

configure_losses()[source]

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MulticlassAccuracy: Overall accuracy (OA) using ‘micro’ averaging. The number of true positives divided by the dataset size. Higher values are better.

  • MulticlassJaccardIndex: Intersection over union (IoU). Uses ‘micro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging, not used here, gives equal weight to each class, useful for balanced performance assessment across imbalanced classes.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted class probabilities.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.BYOLTask(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]

Bases: BaseTask

BYOL: Bootstrap Your Own Latent.

Reference implementation:

If you use this trainer in your research, please cite the following paper:

monitor = 'train_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]

Initialize a new BYOLTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

Changed in version 0.4: backbone_name was renamed to backbone. Changed backbone support from torchvision.models to timm.

Changed in version 0.5: backbone, learning_rate, and learning_rate_schedule_patience were renamed to model, lr, and patience.
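
Example (a minimal self-supervised pretraining sketch; the lightning.pytorch import path and the unlabeled datamodule are assumptions):

    from lightning.pytorch import Trainer

    from torchgeo.trainers import BYOLTask

    # Only training_step is implemented, so schedulers and callbacks monitor 'train_loss'.
    task = BYOLTask(model="resnet50", weights=True, in_channels=3, lr=1e-3)
    trainer = Trainer(max_epochs=100, accelerator="auto")
    # trainer.fit(task, datamodule=unlabeled_datamodule)  # any datamodule yielding "image" batches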

configure_models()[source]

Initialize the model.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Raises:

AssertionError – If channel dimensions are incorrect.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

class torchgeo.trainers.MoCoTask(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]

Bases: BaseTask

MoCo: Momentum Contrast.

Reference implementations:

If you use this trainer in your research, please cite the following papers:

New in version 0.5.

monitor = 'train_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]

Initialize a new MoCoTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • version (int) – Version of MoCo, 1–3.

  • layers (int) – Number of layers in projection head (not used in v1, 2 for v2, 3 for v3).

  • hidden_dim (int) – Number of hidden dimensions in projection head (not used in v1, 2048 for v2, 4096 for v3).

  • output_dim (int) – Number of output dimensions in projection head (not used in v1, 128 for v2, 256 for v3).

  • lr (float) – Learning rate (0.03 x batch_size / 256 for v1/2, 0.6 x batch_size / 256 for v3).

  • weight_decay (float) – Weight decay coefficient (1e-4 for v1/2, 1e-6 for v3).

  • momentum (float) – Momentum of SGD solver (v1/2 only).

  • schedule (Sequence[int]) – Epochs at which to drop lr by 10x (v1/2 only).

  • temperature (float) – Temperature used in InfoNCE loss (0.07 for v1/2, 1 for v3).

  • memory_bank_size (int) – Size of memory bank (65536 for v1/2, 0 for v3).

  • moco_momentum (float) – MoCo momentum of updating key encoder (0.999 for v1/2, 0.99 for v3).

  • gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).

  • size (int) – Size of patch to crop.

  • grayscale_weights (torch.Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.

  • augmentation1 (torch.nn.modules.module.Module | None) – Data augmentation for 1st branch. Defaults to MoCo augmentation.

  • augmentation2 (torch.nn.modules.module.Module | None) – Data augmentation for 2nd branch. Defaults to MoCo augmentation.

Raises:

AssertionError – If an invalid version of MoCo is requested.

Warns:

UserWarning – If hyperparameters do not match MoCo version requested.
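
Example (a minimal sketch of a MoCo v2-style configuration following the hyperparameter guidance above; batch_size=256 is an assumption):

    from torchgeo.trainers import MoCoTask

    batch_size = 256
    task = MoCoTask(
        model="resnet50",
        in_channels=3,
        version=2,
        layers=2,
        hidden_dim=2048,
        output_dim=128,
        lr=0.03 * batch_size / 256,   # v1/2 learning-rate rule
        weight_decay=1e-4,
        temperature=0.07,
        memory_bank_size=65536,
        moco_momentum=0.999,
    )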

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

OptimizerLRSchedulerConfig

forward(x)[source]

Forward pass of the model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output of the model and backbone

Return type:

tuple[torch.Tensor, torch.Tensor]

forward_momentum(x)[source]

Forward pass of the momentum model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output from the momentum model.

Return type:

Tensor

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

class torchgeo.trainers.SimCLRTask(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, momentum=0.9, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]

Bases: BaseTask

SimCLR: a simple framework for contrastive learning of visual representations.

Reference implementation:

If you use this trainer in your research, please cite the following papers:

New in version 0.5.

monitor = 'train_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, momentum=0.9, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]

Initialize a new SimCLRTask instance.

New in version 0.6: The momentum parameter.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • version (int) – Version of SimCLR, 1–2.

  • layers (int) – Number of layers in projection head (2 for v1, 3+ for v2).

  • hidden_dim (int | None) – Number of hidden dimensions in projection head (defaults to output dimension of model).

  • output_dim (int | None) – Number of output dimensions in projection head (defaults to output dimension of model).

  • lr (float) – Learning rate (0.3 x batch_size / 256 is recommended).

  • momentum (float) – Momentum factor.

  • weight_decay (float) – Weight decay coefficient (1e-6 for v1, 1e-4 for v2).

  • temperature (float) – Temperature used in NT-Xent loss.

  • memory_bank_size (int) – Size of memory bank (0 for v1, 64K for v2).

  • gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).

  • size (int) – Size of patch to crop.

  • grayscale_weights (torch.Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.

  • augmentations (torch.nn.modules.module.Module | None) – Data augmentation. Defaults to SimCLR augmentation.

Raises:

AssertionError – If an invalid version of SimCLR is requested.

Warns:

UserWarning – If hyperparameters do not match SimCLR version requested.
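
Example (a minimal sketch of a SimCLR v1-style configuration following the parameter guidance above; batch_size=256 is an assumption):

    from torchgeo.trainers import SimCLRTask

    batch_size = 256
    task = SimCLRTask(
        model="resnet18",
        in_channels=3,
        version=1,
        layers=2,
        lr=0.3 * batch_size / 256,   # recommended learning-rate rule
        weight_decay=1e-6,
        temperature=0.07,
        memory_bank_size=0,          # v1 does not use a memory bank
    )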

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

forward(x)[source]

Forward pass of the model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output of the model and backbone.

Return type:

tuple[torch.Tensor, torch.Tensor]

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Raises:

AssertionError – If channel dimensions are incorrect.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Changed in version 0.6: Changed from Adam to LARS optimizer.

Returns:

Optimizer and learning rate scheduler.

Return type:

OptimizerLRSchedulerConfig

class torchgeo.trainers.BaseTask(ignore=None)[source]

Bases: LightningModule, ABC

Abstract base class for all TorchGeo trainers.

New in version 0.5.

model: Any

Model to train.

monitor = 'val_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'min'

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__(ignore=None)[source]

Initialize a new BaseTask instance.

Parameters:

ignore (collections.abc.Sequence[str] | str | None) – Arguments to skip when saving hyperparameters.

abstract configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

configure_metrics()[source]

Initialize the performance metrics.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

OptimizerLRSchedulerConfig

forward(*args, **kwargs)[source]

Forward pass of the model.

Parameters:
  • args (Any) – Arguments to pass to model.

  • kwargs (Any) – Keyword arguments to pass to model.

Returns:

Output of the model.

Return type:

Any
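
Example (a minimal sketch of a custom trainer subclass; the toy model architecture and the "image"/"label" batch keys are illustrative assumptions):

    from typing import Any

    import torch
    from torch import Tensor

    from torchgeo.trainers import BaseTask


    class MyTask(BaseTask):
        """Toy task: only configure_models is abstract, the other hooks have defaults."""

        monitor = "train_loss"  # no validation_step here, so monitor the training loss
        mode = "min"

        def __init__(self, in_channels: int = 3, num_classes: int = 10, lr: float = 1e-3, patience: int = 10) -> None:
            super().__init__()  # saves hyperparameters and calls the configure_* hooks

        def configure_models(self) -> None:
            # self.hparams holds the arguments passed to __init__
            self.model = torch.nn.Sequential(
                torch.nn.Conv2d(self.hparams["in_channels"], 16, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.AdaptiveAvgPool2d(1),
                torch.nn.Flatten(),
                torch.nn.Linear(16, self.hparams["num_classes"]),
            )

        def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
            x, y = batch["image"], batch["label"]
            loss = torch.nn.functional.cross_entropy(self(x), y)
            self.log("train_loss", loss)
            return loss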

class torchgeo.trainers.IOBenchTask(ignore=None)[source]

Bases: BaseTask

I/O benchmarking.

New in version 0.6.
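
Example (a minimal benchmarking sketch; the lightning.pytorch import path and the datamodule are assumptions):

    from lightning.pytorch import Trainer

    from torchgeo.trainers import IOBenchTask

    # Every step is a no-op, so the wall-clock time of fit() is dominated by data loading.
    task = IOBenchTask()
    trainer = Trainer(max_epochs=1, accelerator="auto")
    # trainer.fit(task, datamodule=datamodule)  # time this call to benchmark your input pipeline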

configure_models()[source]

No-op.

configure_optimizers()[source]

Initialize the optimizer.

Returns:

Optimizer.

Return type:

OptimizerLRSchedulerConfig

training_step(batch, batch_idx, dataloader_idx=0)[source]

No-op.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Zero.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.
