torchgeo.trainers
TorchGeo trainers.
- class torchgeo.trainers.BYOLTask(**kwargs)
Bases: LightningModule
Class for pre-training any PyTorch model using BYOL.
Supports any available timm model as an architecture choice. To see a list of available pretrained models, you can do:
import timm
print(timm.list_models())
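As a rough illustration of what BYOL optimizes: an online network predicts the projection that a target network produces for a second augmented view of the same image, and the target's weights are an exponential moving average (EMA) of the online weights. The plain-Python sketch below shows those two computations on small vectors. It is a generic illustration of the BYOL objective, not torchgeo's implementation; the function names and the decay `tau=0.99` are assumptions. In BYOL the loss is also symmetrized by swapping the two views.

```python
import math


def byol_loss(prediction: list[float], target: list[float]) -> float:
    """Normalized MSE between the online prediction and the target projection.

    Equivalent to 2 - 2 * cosine_similarity(prediction, target).
    """
    dot = sum(p * z for p, z in zip(prediction, target))
    norm_p = math.sqrt(sum(p * p for p in prediction))
    norm_z = math.sqrt(sum(z * z for z in target))
    return 2.0 - 2.0 * dot / (norm_p * norm_z)


def ema_update(target: list[float], online: list[float], tau: float = 0.99) -> list[float]:
    """Move each target-network weight a small step toward its online counterpart."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


print(byol_loss([1.0, 0.0], [2.0, 0.0]))   # 0.0: same direction, so no loss
print(byol_loss([1.0, 0.0], [-1.0, 0.0]))  # 4.0: opposite directions, the maximum
print(ema_update([0.0, 1.0], [1.0, 1.0]))  # target drifts slightly toward online
```

Because the target only drifts toward the online weights, it provides a slowly changing regression target rather than a copy of the network being trained.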
- __init__(**kwargs)
Initialize a LightningModule for pre-training a model with BYOL.
- Keyword Arguments:
in_channels – Number of input channels to the model
backbone – Name of the timm model to use
weights – Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
learning_rate – Learning rate for optimizer
learning_rate_schedule_patience – Patience for learning rate scheduler
- Raises:
ValueError – if kwargs arguments are invalid
Changed in version 0.4: The backbone_name parameter was renamed to backbone. Changed backbone support from torchvision.models to timm.
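The `weights` argument accepts several forms. A minimal sketch of how such an argument could be dispatched is below; `ExampleWeights` and `describe_weights` are hypothetical names, and the check order (existing file path before enum-name lookup) is an assumption rather than a description of torchgeo's code. The same argument appears on ClassificationTask and RegressionTask.

```python
import enum
import os
from typing import Optional, Union


class ExampleWeights(enum.Enum):
    """Hypothetical stand-in for one of torchgeo's pretrained-weight enums."""

    SENTINEL2_RGB = "sentinel2_rgb"


def describe_weights(weights: Optional[Union[bool, str, ExampleWeights]]) -> str:
    """Classify a `weights` argument into one of the documented forms."""
    if weights is None or weights is False:
        return "random initialization"
    if weights is True:
        return "ImageNet weights"
    if isinstance(weights, ExampleWeights):
        return f"weight enum {weights.name}"
    if os.path.exists(weights):
        # An existing file is treated as a saved model state dict.
        return f"state dict loaded from {weights}"
    # Any other string is assumed to be the string form of a weight enum.
    return f"weight enum looked up by name {weights!r}"


print(describe_weights(None))
print(describe_weights(True))
print(describe_weights(ExampleWeights.SENTINEL2_RGB))
print(describe_weights("ExampleWeights.SENTINEL2_RGB"))
```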
- forward(*args, **kwargs)
Forward pass of the model.
- Parameters:
x – tensor of data to run through the model
- Returns:
output from the model
- Return type:
- class torchgeo.trainers.ClassificationTask(**kwargs)
Bases: LightningModule
LightningModule for image classification.
Supports any available timm model as an architecture choice. To see a list of available models, you can do:
import timm
print(timm.list_models())
- __init__(**kwargs)
Initialize the LightningModule with a model and loss function.
- Keyword Arguments:
model – Name of the classification model to use
loss – Name of the loss function, accepts ‘ce’, ‘jaccard’, or ‘focal’
weights – Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
num_classes – Number of prediction classes
in_channels – Number of input channels to the model
learning_rate – Learning rate for optimizer
learning_rate_schedule_patience – Patience for learning rate scheduler
Changed in version 0.4: The classification_model parameter was renamed to model.
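The `loss` argument selects the training objective. The sketch below shows cross-entropy ('ce') and focal loss ('focal') for a single sample in plain Python; it illustrates the formulas rather than torchgeo's code path. The default `gamma=2.0` is an assumption (it is the value recommended in the original focal loss paper, not one stated on this page). With `gamma=0` the focal loss reduces to cross-entropy.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one sample's class scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: list[float], label: int) -> float:
    """'ce': negative log-probability of the true class."""
    return -math.log(softmax(logits)[label])


def focal_loss(logits: list[float], label: int, gamma: float = 2.0) -> float:
    """'focal': cross-entropy scaled by (1 - p_t) ** gamma to down-weight easy samples."""
    p_t = softmax(logits)[label]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logits = [2.0, 0.5, -1.0]
# For this confident, correct prediction the focal loss is much smaller than the CE loss.
print(cross_entropy(logits, 0), focal_loss(logits, 0))
```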
- forward(*args, **kwargs)
Forward pass of the model.
- Parameters:
x – input image
- Returns:
prediction
- Return type:
- training_step(*args, **kwargs)
Compute and return the training loss.
- Parameters:
batch – the output of your DataLoader
- Returns:
training loss
- Return type:
- validation_step(*args, **kwargs)
Compute validation loss and log example predictions.
- Parameters:
batch – the output of your DataLoader
batch_idx – the index of this batch
- test_step(*args, **kwargs)
Compute test loss.
- Parameters:
batch – the output of your DataLoader
- class torchgeo.trainers.MultiLabelClassificationTask(**kwargs)
Bases: ClassificationTask
LightningModule for multi-label image classification.
- __init__(**kwargs)
Initialize the LightningModule with a model and loss function.
- Keyword Arguments:
model – Name of the classification model to use
loss – Name of the loss function, currently only supports ‘bce’
weights – Either ‘random’ or ‘imagenet’
num_classes – Number of prediction classes
in_channels – Number of input channels to the model
learning_rate – Learning rate for optimizer
learning_rate_schedule_patience – Patience for learning rate scheduler
Changed in version 0.4: The classification_model parameter was renamed to model.
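Because the labels of a multi-label sample are not mutually exclusive, the 'bce' loss applies a sigmoid to each class score independently instead of a softmax across classes. A plain-Python sketch, assuming one logit per label and a 0.5 decision threshold (the threshold is an assumption, not something this page specifies); the helper names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logits: list[float], targets: list[int]) -> float:
    """Mean binary cross-entropy, treating each label as an independent yes/no decision."""
    # Stable form of -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))).
    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x))) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def predict_labels(logits: list[float], threshold: float = 0.5) -> list[int]:
    """Switch on every label whose probability reaches the threshold."""
    return [int(sigmoid(x) >= threshold) for x in logits]


scores = [3.0, -2.0, 0.7]  # one logit per label
print(predict_labels(scores))  # [1, 0, 1]
print(bce_with_logits(scores, [1, 0, 1]))
```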
- training_step(*args, **kwargs)
Compute and return the training loss.
- Parameters:
batch – the output of your DataLoader
- Returns:
training loss
- Return type:
- validation_step(*args, **kwargs)
Compute validation loss and log example predictions.
- Parameters:
batch – the output of your DataLoader
batch_idx – the index of this batch
- class torchgeo.trainers.ObjectDetectionTask(**kwargs)
Bases: LightningModule
LightningModule for object detection of images.
Currently supports Faster R-CNN, FCOS, and RetinaNet models from torchvision with one of the following backbone arguments:
['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
New in version 0.4.
- __init__(**kwargs)
Initialize the LightningModule with a model and loss function.
- Keyword Arguments:
model – Name of the detection model type to use
backbone – Name of the model backbone to use
in_channels – Number of channels in input image
num_classes – Number of semantic classes to predict
learning_rate – Learning rate for optimizer
learning_rate_schedule_patience – Patience for learning rate scheduler
- Raises:
ValueError – if kwargs arguments are invalid
Changed in version 0.4: The detection_model parameter was renamed to model.
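The Raises entry above can be pictured as a validation step that runs before any model is built. The sketch checks the backbone against the list documented for this class; the three model strings are assumed spellings (verify them against the torchgeo source before relying on them), and `check_detection_kwargs` is a hypothetical helper.

```python
SUPPORTED_BACKBONES = [
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
]
# Assumed spellings of the three detector types; not confirmed by this page.
SUPPORTED_MODELS = ["faster-rcnn", "fcos", "retinanet"]


def check_detection_kwargs(**kwargs) -> None:
    """Raise ValueError for an unsupported model or backbone name."""
    model = kwargs.get("model")
    backbone = kwargs.get("backbone")
    if model not in SUPPORTED_MODELS:
        raise ValueError(f"Model type '{model}' is not valid; choose from {SUPPORTED_MODELS}")
    if backbone not in SUPPORTED_BACKBONES:
        raise ValueError(f"Backbone type '{backbone}' is not valid; choose from {SUPPORTED_BACKBONES}")


check_detection_kwargs(model="fcos", backbone="resnet50", num_classes=2)  # passes silently
try:
    check_detection_kwargs(model="fcos", backbone="vgg16")
except ValueError as err:
    print(err)
```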
- forward(*args, **kwargs)
Forward pass of the model.
- Parameters:
x – tensor of data to run through the model
- Returns:
output from the model
- Return type:
- training_step(*args, **kwargs)
Compute and return the training loss.
- Parameters:
batch – the output of your DataLoader
- Returns:
training loss
- Return type:
- validation_step(*args, **kwargs)
Compute validation loss and log example predictions.
- Parameters:
batch – the output of your DataLoader
batch_idx – the index of this batch
- test_step(*args, **kwargs)
Compute test mAP (mean average precision).
- Parameters:
batch – the output of your DataLoader
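Mean average precision averages, over classes, the area under each class's precision-recall curve, where a detection counts as a true positive when its IoU with a ground-truth box reaches a threshold. The sketch below computes IoU and a single-class AP with all-point interpolation, assuming each detection has already been matched and marked as a true or false positive. This is one of several AP conventions (COCO-style evaluation, for example, averages over 101 recall points and several IoU thresholds), so it is not necessarily the number this task reports; the helper names are hypothetical.

```python
Box = tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def average_precision(detections: list[tuple[float, bool]], num_gt: int) -> float:
    """AP for one class from (score, is_true_positive) pairs, with all-point interpolation."""
    ranked = sorted(detections, key=lambda d: d[0], reverse=True)
    precisions, recalls, tp = [], [], 0
    for rank, (_, is_tp) in enumerate(ranked, start=1):
        tp += int(is_tp)
        precisions.append(tp / rank)
        recalls.append(tp / num_gt)
    # Replace each precision by the best precision achievable at equal or higher recall.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])
    ap, prev_recall = 0.0, 0.0
    for p, r in zip(precisions, recalls):
        ap += p * (r - prev_recall)  # area of each step under the precision-recall curve
        prev_recall = r
    return ap


print(box_iou((0, 0, 2, 2), (1, 1, 3, 3)))  # 0.142857... (overlap 1, union 7)
print(average_precision([(0.9, True), (0.8, False), (0.7, True)], num_gt=2))
```

The mAP is then the mean of `average_precision` over all classes.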
- class torchgeo.trainers.RegressionTask(**kwargs)
Bases: LightningModule
LightningModule for training models on regression datasets.
Supports any available timm model as an architecture choice. To see a list of available models, you can do:
import timm
print(timm.list_models())
- __init__(**kwargs)
Initialize a new LightningModule for training simple regression models.
- Keyword Arguments:
model – Name of the timm model to use
weights – Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.
num_outputs – Number of prediction outputs
in_channels – Number of input channels to the model
learning_rate – Learning rate for optimizer
learning_rate_schedule_patience – Patience for learning rate scheduler
Changed in version 0.4: Changed regression model support from torchvision.models to timm.
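Every task on this page accepts learning_rate_schedule_patience. Assuming the scheduler behaves like PyTorch's ReduceLROnPlateau (an assumption; this page only names the parameter), the patience is the number of epochs without improvement in the monitored loss that are tolerated before the learning rate is cut. A minimal plain-Python sketch; `PlateauScheduler` is hypothetical, `factor=0.1` matches ReduceLROnPlateau's default, and its threshold, cooldown and min_lr options are omitted.

```python
class PlateauScheduler:
    """Cut the learning rate once the loss fails to improve for more than `patience` epochs."""

    def __init__(self, lr: float, patience: int, factor: float = 0.1) -> None:
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr


scheduler = PlateauScheduler(lr=1e-3, patience=2)
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.93, 0.92]):
    # The third non-improving epoch (epoch 4) exceeds patience=2 and triggers the cut.
    print(epoch, scheduler.step(val_loss))
```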
- forward(*args, **kwargs)
Forward pass of the model.
- Parameters:
x – tensor of data to run through the model
- Returns:
output from the model
- Return type:
- training_step(*args, **kwargs)
Compute and return the training loss.
- Parameters:
batch – the output of your DataLoader
- Returns:
training loss
- Return type:
- validation_step(*args, **kwargs)
Compute validation loss and log example predictions.
- Parameters:
batch – the output of your DataLoader
batch_idx – the index of this batch
- test_step(*args, **kwargs)
Compute test loss.
- Parameters:
batch – the output of your DataLoader
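This page does not say which loss RegressionTask optimizes; mean squared error is a common choice and is assumed here. The sketch computes MSE, along with RMSE and MAE, over a batch in which every sample carries `num_outputs` values; `regression_errors` is a hypothetical helper.

```python
import math


def regression_errors(preds: list[list[float]], targets: list[list[float]]) -> dict[str, float]:
    """MSE, RMSE and MAE over a batch where each sample has num_outputs values."""
    diffs = [p - t for pred, tgt in zip(preds, targets) for p, t in zip(pred, tgt)]
    mse = sum(d * d for d in diffs) / len(diffs)
    mae = sum(abs(d) for d in diffs) / len(diffs)
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae}


# Two samples with num_outputs=2 each.
print(regression_errors([[1.0, 2.0], [3.0, 4.0]], [[1.0, 3.0], [3.0, 2.0]]))
```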
- class torchgeo.trainers.SemanticSegmentationTask(**kwargs)
Bases: LightningModule
LightningModule for semantic segmentation of images.
Supports Segmentation Models Pytorch as an architecture choice, in combination with any of the timm backbones that Segmentation Models Pytorch supports.
- __init__(**kwargs)
Initialize the LightningModule with a model and loss function.
- Keyword Arguments:
model – Name of the segmentation model type to use
backbone – Name of the timm backbone to use
weights – None for random weights, or ‘imagenet’ to use ImageNet-pretrained weights in the backbone
in_channels – Number of channels in input image
num_classes – Number of semantic classes to predict
loss – Name of the loss function, currently supports ‘ce’, ‘jaccard’ or ‘focal’ loss
ignore_index – Optional integer class index to ignore in the loss and metrics
learning_rate – Learning rate for optimizer
learning_rate_schedule_patience – Patience for learning rate scheduler
- Raises:
ValueError – if kwargs arguments are invalid
Changed in version 0.3: The ignore_zeros parameter was renamed to ignore_index.
Changed in version 0.4: The segmentation_model parameter was renamed to model, encoder_name renamed to backbone, and encoder_weights to weights.
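To make `ignore_index` concrete: pixels whose target equals that index are dropped before anything is counted, so they contribute to neither the loss nor the metrics. The sketch applies this to the per-class Jaccard index (IoU) on flattened label lists; the 'jaccard' loss option optimizes a differentiable (soft) version of the same ratio. `per_class_iou` is a hypothetical helper, not torchgeo's metric code.

```python
from typing import Optional


def per_class_iou(preds: list[int], targets: list[int], num_classes: int,
                  ignore_index: Optional[int] = None) -> list[float]:
    """Jaccard index per class, skipping pixels whose target equals ignore_index."""
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, targets):
        if ignore_index is not None and t == ignore_index:
            continue  # ignored pixels count toward neither intersection nor union
        for c in range(num_classes):
            in_p, in_t = p == c, t == c
            inter[c] += int(in_p and in_t)
            union[c] += int(in_p or in_t)
    # A class absent from both predictions and targets has an undefined IoU.
    return [i / u if u else float("nan") for i, u in zip(inter, union)]


preds = [0, 1, 1, 1, 2, 0]
targets = [0, 0, 1, 1, 2, 2]
print(per_class_iou(preds, targets, 3))                  # every pixel counted
print(per_class_iou(preds, targets, 3, ignore_index=0))  # [0.0, 1.0, 0.5]
```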
- forward(*args, **kwargs)
Forward pass of the model.
- Parameters:
x – tensor of data to run through the model
- Returns:
output from the model
- Return type:
- training_step(*args, **kwargs)
Compute and return the training loss.
- Parameters:
batch – the output of your DataLoader
- Returns:
training loss
- Return type:
- validation_step(*args, **kwargs)
Compute validation loss and log example predictions.
- Parameters:
batch – the output of your DataLoader
batch_idx – the index of this batch
- test_step(*args, **kwargs)
Compute test loss.
- Parameters:
batch – the output of your DataLoader
- predict_step(*args, **kwargs)
Compute and return the predictions.
By default, this loops over the images in a dataloader and aggregates the predictions into a list. This may not be desirable if you have many or large images, since the aggregated list could cause out-of-memory errors. In that case, it’s recommended to override this method with a custom predict_step.
- Parameters:
batch – the output of your DataLoader
- Returns:
predicted softmax probabilities
- Return type:
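One way around the memory problem described above is to hand off each batch's softmax probabilities as soon as they are computed instead of keeping every batch in a list. The plain-Python generator below sketches that idea; its names are hypothetical and it stands in for what a custom predict_step would do, not for torchgeo's code.

```python
import math
from typing import Iterable, Iterator


def softmax(scores: list[float]) -> list[float]:
    """Numerically stable softmax over one pixel's class scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def stream_probabilities(batches: Iterable[list[list[float]]]) -> Iterator[list[list[float]]]:
    """Yield each batch's per-pixel softmax probabilities as soon as it is computed."""
    for batch in batches:  # a batch is a list of per-pixel class-score vectors
        yield [softmax(pixel) for pixel in batch]


batches = [[[2.0, 0.1], [0.0, 3.0]], [[1.0, 1.5]]]
for probs in stream_probabilities(batches):
    # Consume each batch here (e.g. write it to disk) before the next one is produced.
    labels = [max(range(len(p)), key=p.__getitem__) for p in probs]
    print(labels)  # [0, 1], then [1]
```

In Lightning itself, a comparable effect is typically achieved with a `BasePredictionWriter` callback and `trainer.predict(..., return_predictions=False)`, so predictions are written out rather than collected.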