
torchgeo.trainers

TorchGeo trainers.

class torchgeo.trainers.ClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]

Bases: BaseTask

Image classification.

__init__(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]

Initialize a new ClassificationTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes.

  • loss (str) – One of ‘ce’, ‘bce’, ‘jaccard’, or ‘focal’.

  • class_weights (torch.Tensor | None) – Optional rescaling weight given to each class and used with ‘ce’ loss.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to linear probe the classifier head.
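The effect of class_weights with the ‘ce’ loss can be sketched without any dependencies. This sketch assumes PyTorch-style weighted cross-entropy with mean reduction: each sample's negative log-likelihood is scaled by the weight of its target class, and the total is divided by the sum of those weights. The function names are illustrative and are not part of TorchGeo.

```python
import math
from typing import Optional


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def weighted_cross_entropy(
    logits: list[list[float]],
    targets: list[int],
    class_weights: Optional[list[float]] = None,
) -> float:
    """Weighted mean cross-entropy: sum(w[y] * nll) / sum(w[y])."""
    num_classes = len(logits[0])
    # Without class weights every class counts equally.
    weights = class_weights if class_weights is not None else [1.0] * num_classes
    total_loss = 0.0
    total_weight = 0.0
    for row, y in zip(logits, targets):
        nll = -log_softmax(row)[y]  # negative log-probability of the true class
        total_loss += weights[y] * nll
        total_weight += weights[y]
    return total_loss / total_weight


logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]]
targets = [0, 1, 2]
print(weighted_cross_entropy(logits, targets))
print(weighted_cross_entropy(logits, targets, [1.0, 5.0, 1.0]))  # class 1 upweighted
```

Upweighting a rare class increases its share of the mean loss, which is the usual motivation for passing class_weights on imbalanced datasets.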

Changed in version 0.4: classification_model was renamed to model.

New in version 0.5: The class_weights and freeze_backbone parameters.

Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MulticlassAccuracy: The number of true positives divided by the dataset size. Both overall accuracy (OA) using ‘micro’ averaging and average accuracy (AA) using ‘macro’ averaging are reported. Higher values are better.

  • MulticlassJaccardIndex: Intersection over union (IoU). Uses ‘macro’ averaging. Higher values are better.

  • MulticlassFBetaScore: F1 score. The harmonic mean of precision and recall. Uses ‘micro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
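The gap between overall accuracy (micro) and average accuracy (macro) is easiest to see on an imbalanced toy example. The sketch below computes both definitions from plain Python lists; it illustrates the definitions only and is not how torchmetrics computes them internally (for instance, it only averages over classes present in the targets).

```python
from collections import Counter


def overall_accuracy(preds: list[int], targets: list[int]) -> float:
    """Micro-averaged accuracy (OA): correct predictions / total samples."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def average_accuracy(preds: list[int], targets: list[int]) -> float:
    """Macro-averaged accuracy (AA): mean of per-class accuracies (recall)."""
    support = Counter(targets)  # number of samples per true class
    hits = Counter(t for p, t in zip(preds, targets) if p == t)
    per_class = [hits[c] / support[c] for c in support]
    return sum(per_class) / len(per_class)


# Imbalanced toy data: 8 samples of class 0, 2 samples of class 1.
targets = [0] * 8 + [1] * 2
# A degenerate classifier that always predicts the majority class.
preds = [0] * 10

print(overall_accuracy(preds, targets))  # 0.8
print(average_accuracy(preds, targets))  # 0.5
```

OA rewards the majority-class classifier with 0.8, while AA exposes that the minority class is never recovered.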

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted class probabilities.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.MultiLabelClassificationTask(model='resnet50', weights=None, in_channels=3, num_classes=1000, loss='ce', class_weights=None, lr=0.001, patience=10, freeze_backbone=False)[source]

Bases: ClassificationTask

Multi-label image classification.

configure_metrics()[source]

Initialize the performance metrics.

  • MultilabelAccuracy: The number of true positives divided by the dataset size. Both overall accuracy (OA) using ‘micro’ averaging and average accuracy (AA) using ‘macro’ averaging are reported. Higher values are better.

  • MultilabelFBetaScore: F1 score. The harmonic mean of precision and recall. Uses ‘micro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
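In the multi-label setting each sample can carry several positive labels, so ‘micro’ F1 pools true positives, false positives, and false negatives across every (sample, label) pair before computing precision and recall. A minimal sketch on 0/1 indicator rows (the function name is illustrative, not a TorchGeo or torchmetrics API):

```python
def micro_f1(preds: list[list[int]], targets: list[list[int]]) -> float:
    """Micro-averaged F1 over binary multi-label indicator rows."""
    tp = fp = fn = 0
    for pred_row, target_row in zip(preds, targets):
        for p, t in zip(pred_row, target_row):
            tp += int(p == 1 and t == 1)  # label predicted and present
            fp += int(p == 1 and t == 0)  # label predicted but absent
            fn += int(p == 0 and t == 1)  # label missed
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Two images, three labels each.
targets = [[1, 0, 1], [0, 1, 0]]
preds = [[1, 0, 0], [0, 1, 1]]
print(micro_f1(preds, targets))  # tp=2, fp=1, fn=1
```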

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted class probabilities.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.ObjectDetectionTask(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]

Bases: BaseTask

Object detection.

New in version 0.4.

monitor = 'val_map'

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'max'

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]

Initialize a new ObjectDetectionTask instance.

Parameters:
  • model (str) – Name of the torchvision model to use. One of ‘faster-rcnn’, ‘fcos’, or ‘retinanet’.

  • backbone (str) – Name of the torchvision backbone to use. One of ‘resnet18’, ‘resnet34’, ‘resnet50’, ‘resnet101’, ‘resnet152’, ‘resnext50_32x4d’, ‘resnext101_32x8d’, ‘wide_resnet50_2’, or ‘wide_resnet101_2’.

  • weights (bool | None) – Initial model weights. True for ImageNet weights, False or None for random weights.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes.

  • trainable_layers (int) – Number of trainable layers.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the detection head.

Changed in version 0.4: detection_model was renamed to model.

New in version 0.5: The freeze_backbone parameter.

Changed in version 0.5: pretrained, learning_rate, and learning_rate_schedule_patience were renamed to weights, lr, and patience.

configure_models()[source]

Initialize the model.

Raises:

ValueError – If model or backbone are invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MeanAveragePrecision: Mean average precision (mAP) and mean average recall (mAR). Precision is the number of true positives divided by the number of true positives + false positives. Recall is the number of true positives divided by the number of true positives + false negatives. Uses ‘macro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
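For detection, deciding which predicted boxes count as true positives requires matching them to ground-truth boxes. The sketch below assumes a common simplification: a single class, boxes as (xmin, ymin, xmax, ymax), and greedy one-to-one matching in descending score order at an IoU threshold of 0.5. It yields one precision/recall point; full mAP (as computed by torchmetrics' MeanAveragePrecision) additionally sweeps score thresholds, IoU thresholds, and classes. All names here are illustrative.

```python
Box = tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap width
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap height
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def precision_recall(
    preds: list[Box], scores: list[float], gts: list[Box], iou_thresh: float = 0.5
) -> tuple[float, float]:
    """Greedily match predictions (highest score first) to unmatched ground truth."""
    order = sorted(range(len(preds)), key=lambda i: scores[i], reverse=True)
    matched: set[int] = set()
    tp = 0
    for i in order:
        best_j, best_iou = -1, iou_thresh
        for j, gt in enumerate(gts):
            iou = box_iou(preds[i], gt)
            if j not in matched and iou >= best_iou:
                best_j, best_iou = j, iou
        if best_j >= 0:  # each ground-truth box can be matched at most once
            matched.add(best_j)
            tp += 1
    fp = len(preds) - tp  # unmatched predictions
    fn = len(gts) - tp  # missed ground-truth boxes
    precision = tp / (tp + fp) if preds else 0.0
    recall = tp / (tp + fn) if gts else 0.0
    return precision, recall


gts = [(0, 0, 10, 10), (20, 20, 30, 30)]
preds = [(1, 1, 10, 10), (50, 50, 60, 60)]
print(precision_recall(preds, [0.9, 0.8], gts))
```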

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted bounding boxes.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted bounding boxes, one dictionary per image.

Return type:

list[dict[str, torch.Tensor]]

class torchgeo.trainers.PixelwiseRegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Bases: RegressionTask

LightningModule for pixelwise regression of images.

New in version 0.5.

configure_models()[source]

Initialize the model.

class torchgeo.trainers.RegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Bases: BaseTask

Regression.

__init__(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Initialize a new RegressionTask instance.

Parameters:
  • model (str) – Name of the timm or smp model to use.

  • backbone (str) – Name of the timm or smp backbone to use. Only applicable to PixelwiseRegressionTask.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • num_outputs (int) – Number of prediction outputs.

  • num_filters (int) – Number of filters. Only applicable when model=’fcn’.

  • loss (str) – One of ‘mse’ or ‘mae’.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to linear probe the regression head. Does not support FCN models.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the regression head. Does not support FCN models. Only applicable to PixelwiseRegressionTask.

Changed in version 0.4: Changed regression model support from torchvision.models to timm.

New in version 0.5: The freeze_backbone and freeze_decoder parameters.

Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MeanSquaredError: The average of the squared differences between the predicted and actual values (MSE) and its square root (RMSE). Lower values are better.

  • MeanAbsoluteError: The average of the absolute differences between the predicted and actual values (MAE). Lower values are better.
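The three regression metrics follow directly from the per-sample errors. A dependency-free sketch (the function name is illustrative):

```python
import math


def regression_metrics(preds: list[float], targets: list[float]) -> dict[str, float]:
    """Return MSE, RMSE, and MAE for paired predictions and targets."""
    errors = [p - t for p, t in zip(preds, targets)]
    mse = sum(e * e for e in errors) / len(errors)  # mean squared error
    mae = sum(abs(e) for e in errors) / len(errors)  # mean absolute error
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae}


# Errors are [1, 0, -3]: the single large error dominates MSE more than MAE.
print(regression_metrics([2.0, 4.0, 6.0], [1.0, 4.0, 9.0]))
```

Because squaring amplifies large errors, MSE and RMSE are more sensitive to outliers than MAE; RMSE is reported in the same units as the target.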

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted regression values.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted regression values.

Return type:

Tensor

class torchgeo.trainers.SemanticSegmentationTask(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Bases: BaseTask

Semantic Segmentation.

__init__(model='unet', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, num_filters=3, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]

Initialize a new SemanticSegmentationTask instance.

Parameters:
  • model (str) – Name of the smp model to use.

  • backbone (str) – Name of the timm or smp backbone to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. FCN model does not support pretrained weights. Pretrained ViT weight enums are not supported yet.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes.

  • num_filters (int) – Number of filters. Only applicable when model=’fcn’.

  • loss (str) – One of ‘ce’, ‘jaccard’, or ‘focal’.

  • class_weights (torch.Tensor | None) – Optional rescaling weight given to each class and used with ‘ce’ loss.

  • ignore_index (int | None) – Optional integer class index to ignore in the loss and metrics.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the segmentation head.

Changed in version 0.3: ignore_zeros was renamed to ignore_index.

Changed in version 0.4: segmentation_model, encoder_name, and encoder_weights were renamed to model, backbone, and weights.

Changed in version 0.5: The weights parameter now supports WeightsEnum values and checkpoint paths. learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

Changed in version 0.6: The ignore_index parameter now works for jaccard loss.

configure_models()[source]

Initialize the model.

Raises:

ValueError – If model is invalid.

configure_losses()[source]

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]

Initialize the performance metrics.

  • MulticlassAccuracy: Overall accuracy (OA) using ‘micro’ averaging. The number of true positives divided by the dataset size. Higher values are better.

  • MulticlassJaccardIndex: Intersection over union (IoU). Uses ‘micro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging, not used here, gives equal weight to each class and is useful for balanced performance assessment across imbalanced classes.
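The difference between ‘micro’ and ‘macro’ IoU, and the role of ignore_index, can be sketched on flattened per-pixel labels. Micro IoU here pools intersections and unions over all classes before dividing; macro IoU averages per-class IoU. Pixels whose target equals ignore_index are skipped entirely, and the ignored class is excluded from both averages. This is an illustrative sketch, not TorchGeo's or torchmetrics' implementation.

```python
from typing import Optional


def jaccard_scores(
    preds: list[int],
    targets: list[int],
    num_classes: int,
    ignore_index: Optional[int] = None,
) -> tuple[float, float]:
    """Return (micro IoU, macro IoU) for flattened per-pixel class labels."""
    inter = [0] * num_classes
    union = [0] * num_classes  # per class: TP + FP + FN
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue  # ignored pixels contribute nothing
        if p == t:
            inter[t] += 1
            union[t] += 1
        else:
            union[p] += 1  # false positive for the predicted class
            union[t] += 1  # false negative for the true class
    classes = [c for c in range(num_classes) if c != ignore_index and union[c] > 0]
    micro = sum(inter[c] for c in classes) / sum(union[c] for c in classes)
    macro = sum(inter[c] / union[c] for c in classes) / len(classes)
    return micro, macro


# Class 0 dominates; one class-1 pixel is missed; one pixel is labelled 255 (ignored).
preds = [0] * 6 + [0, 1, 1]
targets = [0] * 6 + [1, 1, 255]
print(jaccard_scores(preds, targets, num_classes=2, ignore_index=255))
```

Micro IoU (7/9) sits close to the dominant class's IoU, while macro IoU is pulled down by the poorly segmented minority class.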

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the validation loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the test loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the predicted class probabilities.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.BYOLTask(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]

Bases: BaseTask

BYOL: Bootstrap Your Own Latent.

Reference implementation:

If you use this trainer in your research, please cite the following paper:

monitor = 'train_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]

Initialize a new BYOLTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

Changed in version 0.4: backbone_name was renamed to backbone. Changed backbone support from torchvision.models to timm.

Changed in version 0.5: backbone, learning_rate, and learning_rate_schedule_patience were renamed to model, lr, and patience.

configure_models()[source]

Initialize the model.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Raises:

AssertionError – If channel dimensions are incorrect.

Return type:

Tensor
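The BYOL paper defines its training objective as the squared distance between the L2-normalized online prediction and target projection, which equals 2 - 2 times their cosine similarity, symmetrized over the two augmented views. The dependency-free sketch below follows that definition; the exact tensor implementation inside BYOLTask.training_step is not shown here, and the function names are illustrative.

```python
import math


def normalized_mse(p: list[float], z: list[float]) -> float:
    """Squared L2 distance between L2-normalized vectors, i.e. 2 - 2 * cos(p, z)."""
    norm_p = math.sqrt(sum(x * x for x in p))
    norm_z = math.sqrt(sum(x * x for x in z))
    cos = sum(a * b for a, b in zip(p, z)) / (norm_p * norm_z)
    return 2.0 - 2.0 * cos


def byol_loss(
    online_1: list[float], online_2: list[float], target_1: list[float], target_2: list[float]
) -> float:
    """Symmetrized loss: each view's online prediction regresses the other view's target."""
    return normalized_mse(online_1, target_2) + normalized_mse(online_2, target_1)


print(normalized_mse([1.0, 0.0], [0.0, 3.0]))  # orthogonal vectors
```

Only the direction of each vector matters, so the loss ranges from 0 (aligned) to 4 (opposite) per term.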

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

class torchgeo.trainers.MoCoTask(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]

Bases: BaseTask

MoCo: Momentum Contrast.

Reference implementations:

If you use this trainer in your research, please cite the following papers:

New in version 0.5.

monitor = 'train_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]

Initialize a new MoCoTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • version (int) – Version of MoCo, 1–3.

  • layers (int) – Number of layers in projection head (not used in v1, 2 for v2, 3 for v3).

  • hidden_dim (int) – Number of hidden dimensions in projection head (not used in v1, 2048 for v2, 4096 for v3).

  • output_dim (int) – Number of output dimensions in projection head (not used in v1, 128 for v2, 256 for v3).

  • lr (float) – Learning rate (0.03 x batch_size / 256 for v1/2, 0.6 x batch_size / 256 for v3).

  • weight_decay (float) – Weight decay coefficient (1e-4 for v1/2, 1e-6 for v3).

  • momentum (float) – Momentum of SGD solver (v1/2 only).

  • schedule (Sequence[int]) – Epochs at which to drop lr by 10x (v1/2 only).

  • temperature (float) – Temperature used in InfoNCE loss (0.07 for v1/2, 1 for v3).

  • memory_bank_size (int) – Size of memory bank (65536 for v1/2, 0 for v3).

  • moco_momentum (float) – Momentum used to update the key encoder (0.999 for v1/2, 0.99 for v3).

  • gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).

  • size (int) – Size of patch to crop.

  • grayscale_weights (torch.Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used by the default MoCo augmentations (when augmentation1 or augmentation2 is None). Defaults to average of all bands.

  • augmentation1 (torch.nn.modules.module.Module | None) – Data augmentation for 1st branch. Defaults to MoCo augmentation.

  • augmentation2 (torch.nn.modules.module.Module | None) – Data augmentation for 2nd branch. Defaults to MoCo augmentation.
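The lr description above is a linear scaling rule: a base rate multiplied by batch_size / 256. A hypothetical helper (not part of TorchGeo) makes the arithmetic explicit:

```python
def moco_lr(version: int, batch_size: int) -> float:
    """Linear-scaling learning rate suggested for each MoCo version.

    Hypothetical helper: the base rates (0.03 for v1/v2, 0.6 for v3) are
    taken from the MoCoTask lr parameter description.
    """
    if version not in (1, 2, 3):
        raise ValueError(f"MoCo version must be 1, 2, or 3, got {version}")
    base = 0.6 if version == 3 else 0.03
    return base * batch_size / 256


print(moco_lr(2, 256))   # 0.03
print(moco_lr(3, 4096))  # 9.6
```

The default lr=9.6 is consistent with the v3 rule at a batch size of 4096. By the same arithmetic, SimCLRTask's default lr=4.8 equals 0.3 x 4096 / 256.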

Raises:

AssertionError – If an invalid version of MoCo is requested.

Warns:

UserWarning – If hyperparameters do not match MoCo version requested.
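The raise/warn behaviour can be pictured as a comparison against the per-version values listed in the parameter descriptions. The sketch below is hypothetical: the table and check_moco_hparams are illustrations of the documented contract (AssertionError for an invalid version, UserWarning for mismatched hyperparameters), not TorchGeo's actual checks, which may compare a different set of values.

```python
import warnings

# Per-version values quoted in the MoCoTask parameter list (hypothetical table).
MOCO_RECOMMENDED = {
    1: {"weight_decay": 1e-4, "temperature": 0.07,
        "memory_bank_size": 65536, "moco_momentum": 0.999},
    2: {"layers": 2, "hidden_dim": 2048, "output_dim": 128, "weight_decay": 1e-4,
        "temperature": 0.07, "memory_bank_size": 65536, "moco_momentum": 0.999},
    3: {"layers": 3, "hidden_dim": 4096, "output_dim": 256, "weight_decay": 1e-6,
        "temperature": 1, "memory_bank_size": 0, "moco_momentum": 0.99},
}


def check_moco_hparams(version: int, **hparams: float) -> None:
    """Assert a valid version, then warn for each value that differs from the table."""
    assert version in MOCO_RECOMMENDED, f"Invalid MoCo version: {version}"
    for name, value in hparams.items():
        expected = MOCO_RECOMMENDED[version].get(name)
        if expected is not None and value != expected:
            warnings.warn(
                f"MoCo v{version} recommends {name}={expected}, got {value}", UserWarning
            )


check_moco_hparams(3, temperature=0.07)  # emits one UserWarning
```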

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

OptimizerLRSchedulerConfig

forward(x)[source]

Forward pass of the model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output of the model and backbone.

Return type:

tuple[torch.Tensor, torch.Tensor]

forward_momentum(x)[source]

Forward pass of the momentum model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output from the momentum model.

Return type:

Tensor
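The momentum model is not trained by gradient descent; in MoCo its parameters track the online (query) encoder as an exponential moving average controlled by moco_momentum. A minimal sketch on flat parameter lists (the function name is illustrative, and TorchGeo's tensor-level update is not shown here):

```python
def momentum_update(
    key_params: list[float], query_params: list[float], m: float
) -> list[float]:
    """Exponential moving average: theta_k <- m * theta_k + (1 - m) * theta_q."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


key = [0.0, 0.0]
query = [1.0, -1.0]
for _ in range(3):
    key = momentum_update(key, query, m=0.99)  # key drifts slowly toward query
print(key)
```

A larger momentum (0.999 for v1/2 versus 0.99 for v3) makes the key encoder change more slowly, which keeps the keys it produces more consistent across iterations.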

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

class torchgeo.trainers.SimCLRTask(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]

Bases: BaseTask

SimCLR: a simple framework for contrastive learning of visual representations.

Reference implementation:

If you use this trainer in your research, please cite the following papers:

New in version 0.5.

monitor = 'train_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]

Initialize a new SimCLRTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (torchvision.models._api.WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • version (int) – Version of SimCLR, 1–2.

  • layers (int) – Number of layers in projection head (2 for v1, 3+ for v2).

  • hidden_dim (int | None) – Number of hidden dimensions in projection head (defaults to output dimension of model).

  • output_dim (int | None) – Number of output dimensions in projection head (defaults to output dimension of model).

  • lr (float) – Learning rate (0.3 x batch_size / 256 is recommended).

  • weight_decay (float) – Weight decay coefficient (1e-6 for v1, 1e-4 for v2).

  • temperature (float) – Temperature used in NT-Xent loss.

  • memory_bank_size (int) – Size of memory bank (0 for v1, 64K for v2).

  • gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).

  • size (int) – Size of patch to crop.

  • grayscale_weights (torch.Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.

  • augmentations (torch.nn.modules.module.Module | None) – Data augmentation. Defaults to SimCLR augmentation.
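The temperature parameter enters the NT-Xent loss from the SimCLR paper. The dependency-free sketch below follows the paper's definition for N positive pairs without a memory bank (so it does not model memory_bank_size); each of the 2N embeddings is an anchor whose positive is the other view of the same image, and all remaining embeddings act as negatives. Names are illustrative, not TorchGeo APIs.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def nt_xent(
    view1: list[list[float]], view2: list[list[float]], temperature: float = 0.07
) -> float:
    """NT-Xent loss for N positive pairs (view1[i], view2[i]), averaged over 2N anchors."""
    z = view1 + view2  # 2N embeddings; the positive of index i is (i + n) % (2n)
    n = len(view1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)
        # Similarities to every other embedding, scaled by the temperature.
        logits = [cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        pos_logit = cosine(z[i], z[pos]) / temperature
        m = max(logits)  # stabilise the log-sum-exp
        log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_denom - pos_logit  # -log(softmax probability of the positive)
    return total / (2 * n)


aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned, swapped)  # matching views give a much lower loss
```

A smaller temperature sharpens the softmax over similarities, so hard negatives are penalized more strongly.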

Raises:

AssertionError – If an invalid version of SimCLR is requested.

Warns:

UserWarning – If hyperparameters do not match SimCLR version requested.

configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

forward(x)[source]

Forward pass of the model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output of the model and backbone.

Return type:

tuple[torch.Tensor, torch.Tensor]

training_step(batch, batch_idx, dataloader_idx=0)[source]

Compute the training loss and additional metrics.

Parameters:
  • batch (Any) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Raises:

AssertionError – If channel dimensions are incorrect.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]

No-op, does nothing.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

OptimizerLRSchedulerConfig

class torchgeo.trainers.BaseTask(ignore=None)[source]

Bases: LightningModule, ABC

Abstract base class for all TorchGeo trainers.

New in version 0.5.

model: Any

Model to train.

monitor = 'val_loss'

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'min'

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__(ignore=None)[source]

Initialize a new BaseTask instance.

Parameters:

ignore (collections.abc.Sequence[str] | str | None) – Arguments to skip when saving hyperparameters.

abstract configure_models()[source]

Initialize the model.

configure_losses()[source]

Initialize the loss criterion.

configure_metrics()[source]

Initialize the performance metrics.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

OptimizerLRSchedulerConfig
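The monitor and mode attributes, together with each task's patience parameter, govern the learning rate scheduler. The sketch below assumes reduce-on-plateau semantics similar to PyTorch's ReduceLROnPlateau (multiply the learning rate by a factor of 0.1 once more than patience epochs pass without improvement) and ignores that scheduler's threshold and cooldown options. Whether configure_optimizers uses exactly this scheduler and factor is an assumption, and PlateauScheduler is a hypothetical class.

```python
class PlateauScheduler:
    """Minimal reduce-on-plateau logic driven by a monitored metric (hypothetical)."""

    def __init__(self, lr: float, mode: str = "min", patience: int = 10, factor: float = 0.1):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.lr = lr
        self.mode = mode
        self.patience = patience
        self.factor = factor
        self.best = float("inf") if mode == "min" else float("-inf")
        self.bad_epochs = 0  # consecutive epochs without improvement

    def step(self, metric: float) -> float:
        """Record one epoch's monitored value and return the (possibly reduced) lr."""
        improved = metric < self.best if self.mode == "min" else metric > self.best
        if improved:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr


# ObjectDetectionTask monitors 'val_map' with mode='max'.
sched = PlateauScheduler(lr=0.001, mode="max", patience=2)
for val_map in [0.10, 0.20, 0.20, 0.19, 0.18]:
    lr = sched.step(val_map)
print(lr)  # reduced after three epochs without improvement
```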

forward(*args, **kwargs)[source]

Forward pass of the model.

Parameters:
  • args (Any) – Arguments to pass to model.

  • kwargs (Any) – Keyword arguments to pass to model.

Returns:

Output of the model.

Return type:

Any
