
torchgeo.trainers

TorchGeo trainers.

class torchgeo.trainers.BYOLTask(**kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

Class for pre-training any PyTorch model using BYOL.

__init__(**kwargs)[source]

Initialize a LightningModule for pre-training a model with BYOL.

Keyword Arguments
  • in_channels – number of channels in the input imagery

  • encoder_name – either “resnet18” or “resnet50”

  • imagenet_pretraining – bool indicating whether to use ImageNet-pretrained weights

Raises

ValueError – if the keyword arguments are invalid
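The documented contract of the constructor (encoder_name limited to “resnet18” or “resnet50”, ValueError on invalid keyword arguments) can be sketched as a standalone check. This is a minimal illustration, not torchgeo's implementation: the helper name validate_byol_kwargs is hypothetical, and the checks on in_channels and imagenet_pretraining are assumptions beyond what this reference states.

```python
from typing import Any, Dict

# Encoder names documented as valid for BYOLTask.
VALID_ENCODERS = {"resnet18", "resnet50"}


def validate_byol_kwargs(kwargs: Dict[str, Any]) -> None:
    """Hypothetical check mirroring the documented BYOLTask keyword arguments."""
    encoder_name = kwargs.get("encoder_name")
    if encoder_name not in VALID_ENCODERS:
        raise ValueError(
            f"encoder_name={encoder_name!r} is invalid; "
            f"expected one of {sorted(VALID_ENCODERS)}"
        )
    # Assumption: in_channels must be a positive integer.
    in_channels = kwargs.get("in_channels")
    if not isinstance(in_channels, int) or in_channels < 1:
        raise ValueError(f"in_channels must be a positive int, got {in_channels!r}")
    # Assumption: imagenet_pretraining must be an actual bool when given.
    if not isinstance(kwargs.get("imagenet_pretraining", False), bool):
        raise ValueError("imagenet_pretraining must be a bool")


validate_byol_kwargs({"in_channels": 4, "encoder_name": "resnet18", "imagenet_pretraining": True})
```

Passing encoder_name="vgg16", for example, raises ValueError, in line with the documented Raises entry.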

config_task()[source]

Configures the task based on the keyword arguments passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict”, as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]
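The “lr dict” referenced above is a plain dictionary. The sketch below shows one shape PyTorch Lightning accepts from configure_optimizers: an optimizer plus a nested lr_scheduler entry that names the metric to monitor. The helper make_lr_dict and the default monitor name "val_loss" are assumptions for illustration; which optimizer, scheduler, and metric these trainers actually use is not stated in this reference.

```python
from typing import Any, Dict


def make_lr_dict(optimizer: Any, scheduler: Any, monitor: str = "val_loss") -> Dict[str, Any]:
    """Hypothetical helper: build an "lr dict" of the shape Lightning accepts."""
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Name of a logged metric; metric-driven schedulers such as
            # ReduceLROnPlateau step on this value.
            "monitor": monitor,
        },
    }
```

In a real LightningModule the arguments would be objects such as torch.optim.Adam(self.parameters(), lr=1e-3) and torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), and the monitored key has to match a metric logged with self.log.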

forward(x)[source]

Forward pass of the model.

Parameters

x (torch.Tensor) – tensor of data to run through the model

Returns

output from the model

Return type

Any

test_step(*args)[source]

No-op, does nothing.

training_step(batch, batch_idx)[source]

Training step - reports BYOL loss.

Parameters
  • batch (Dict[str, Any]) – current batch

  • batch_idx (int) – index of current batch

Returns

training loss

Return type

torch.Tensor
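For reference, the loss in the original BYOL formulation is the squared L2 distance between the L2-normalized online prediction p and the L2-normalized target projection z, which equals 2 - 2 * cos(p, z), summed over both augmented views with the views swapped. The pure-Python sketch below shows only that arithmetic on single vectors; batching, the target-network moving-average update, and the exact reduction BYOLTask uses are not shown and may differ.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float], eps: float = 1e-12) -> List[float]:
    # L2-normalize, guarding against division by zero for an all-zero vector.
    norm = max(math.sqrt(sum(x * x for x in v)), eps)
    return [x / norm for x in v]


def byol_regression_loss(p: Sequence[float], z: Sequence[float]) -> float:
    """Squared L2 distance between normalized p and z, i.e. 2 - 2 * cos(p, z).

    Ranges from 0 (same direction) to 4 (opposite directions).
    """
    p_n, z_n = _normalize(p), _normalize(z)
    return sum((a - b) ** 2 for a, b in zip(p_n, z_n))


def symmetric_byol_loss(p1: Sequence[float], z2: Sequence[float],
                        p2: Sequence[float], z1: Sequence[float]) -> float:
    # p1/p2: online predictions for views 1 and 2; z1/z2: target projections.
    # Each view's prediction is regressed onto the *other* view's target.
    return byol_regression_loss(p1, z2) + byol_regression_loss(p2, z1)
```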

validation_step(batch, batch_idx)[source]

Logs iteration-level validation loss.

Parameters
  • batch (Dict[str, Any]) – current batch

  • batch_idx (int) – index of current batch

class torchgeo.trainers.ClassificationTask(**kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

LightningModule for image classification.

__init__(**kwargs)[source]

Initialize the LightningModule with a model and loss function.

Keyword Arguments
  • classification_model – Name of the classification model to use

  • loss – Name of the loss function

  • weights – Either “random”, “imagenet_only”, “imagenet_and_random”, or “random_rgb”
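Since weights accepts only four string values, a guard like the one below catches typos before any model is built. This is a hypothetical sketch: check_weights_option is not a torchgeo function, the model and loss names in the example dict are placeholders, and what each option does internally is not described in this reference.

```python
# Options documented for ClassificationTask's ``weights`` keyword argument.
VALID_WEIGHTS = ("random", "imagenet_only", "imagenet_and_random", "random_rgb")


def check_weights_option(weights: str) -> str:
    """Hypothetical guard: reject any ``weights`` value outside the documented set."""
    if weights not in VALID_WEIGHTS:
        raise ValueError(
            f"weights={weights!r} is invalid; expected one of {', '.join(VALID_WEIGHTS)}"
        )
    return weights


# Example hyperparameters; the model and loss names are placeholders, not confirmed values.
hparams = {"classification_model": "<model name>", "loss": "<loss name>", "weights": "imagenet_only"}
check_weights_option(hparams["weights"])
```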

config_model()[source]

Configures the model based on the keyword arguments passed to the constructor.

config_task()[source]

Configures the task based on the keyword arguments passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict”, as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]

forward(x)[source]

Forward pass of the model.

Parameters

x (torch.Tensor) – input image

Returns

prediction

Return type

Any

test_epoch_end(outputs)[source]

Logs epoch-level test metrics.

Parameters

outputs (Any) – list of items returned by test_step

test_step(batch, batch_idx)[source]

Test step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

training_epoch_end(outputs)[source]

Logs epoch-level training metrics.

Parameters

outputs (Any) – list of items returned by training_step
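The *_epoch_end hooks receive the per-step outputs for the whole epoch. The sketch below illustrates why epoch-level metrics are usually computed from pooled per-batch statistics rather than by averaging per-batch scores: when the last batch is smaller, a plain mean of batch accuracies over-weights it. The function names and the (num_correct, num_samples) output format are assumptions; how these trainers aggregate internally is not shown here.

```python
from typing import List, Tuple


def pooled_accuracy(batch_counts: List[Tuple[int, int]]) -> float:
    """Epoch accuracy from per-batch (num_correct, num_samples) pairs."""
    correct = sum(c for c, _ in batch_counts)
    total = sum(n for _, n in batch_counts)
    return correct / total


def mean_of_batch_accuracies(batch_counts: List[Tuple[int, int]]) -> float:
    """Naive alternative: average the per-batch accuracies (biased by small batches)."""
    return sum(c / n for c, n in batch_counts) / len(batch_counts)


counts = [(30, 32), (31, 32), (1, 4)]  # the final batch holds only 4 samples
print(pooled_accuracy(counts), mean_of_batch_accuracies(counts))
```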

training_step(batch, batch_idx)[source]

Training step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

Returns

training loss

Return type

torch.Tensor
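The loss used in the training step is selected by name through the loss keyword argument. Assuming a softmax cross-entropy loss is chosen (the accepted names are not listed in this reference), the value returned per batch is the mean of log-sum-exp(logits) minus the logit of the true class. A dependency-free sketch of that computation:

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean softmax cross-entropy over a batch of logit vectors and integer class labels."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max before exponentiating, for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]
    return total / len(targets)
```

With two equal logits the loss is ln 2 regardless of the label, which is a quick sanity check for an untrained classifier on a two-class problem.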

validation_epoch_end(outputs)[source]

Logs epoch-level validation metrics.

Parameters

outputs (Any) – list of items returned by validation_step

validation_step(batch, batch_idx)[source]

Validation step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

class torchgeo.trainers.MultiLabelClassificationTask(**kwargs)

Bases: torchgeo.trainers.ClassificationTask

LightningModule for multi-label image classification.

__init__(**kwargs)[source]

Initialize the LightningModule with a model and loss function.

Keyword Arguments
  • classification_model – Name of the classification model to use

  • loss – Name of the loss function

  • weights – Either “random”, “imagenet_only”, “imagenet_and_random”, or “random_rgb”

config_task()[source]

Configures the task based on the keyword arguments passed to the constructor.

test_step(batch, batch_idx)[source]

Test step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

training_step(batch, batch_idx)[source]

Training step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

Returns

training loss

Return type

torch.Tensor
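In multi-label classification each label is an independent yes/no decision, so a per-label sigmoid with binary cross-entropy is the usual objective, in contrast to the single softmax over classes used for ordinary classification. The sketch below assumes that objective and a 0.5 decision threshold; this reference does not say which loss or threshold MultiLabelClassificationTask uses.

```python
import math
from typing import List, Sequence


def bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy over independent per-label logits (numerically stable form)."""
    total = 0.0
    for x, y in zip(logits, targets):
        # max(x, 0) - x*y + log(1 + exp(-|x|)) equals
        # -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] without overflow.
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def predict_labels(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Decide each label independently: sigmoid(logit) >= threshold, with 0 < threshold < 1."""
    # sigmoid(x) >= t  <=>  x >= log(t / (1 - t)), which avoids calling exp() on large inputs.
    cutoff = math.log(threshold / (1.0 - threshold))
    return [int(x >= cutoff) for x in logits]
```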

validation_step(batch, batch_idx)[source]

Validation step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

class torchgeo.trainers.RegressionTask(**kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

LightningModule for training models on regression datasets.

__init__(**kwargs)[source]

Initialize a new LightningModule for training simple regression models.

Keyword Arguments
  • model – Name of the model to use

  • learning_rate – Initial learning rate to use in the optimizer

  • learning_rate_schedule_patience – Patience parameter for the LR scheduler
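A patience parameter typically belongs to a reduce-on-plateau scheduler: the learning rate is cut only after the monitored loss has stopped improving for more than patience consecutive epochs. The sketch below assumes that kind of scheduler (this reference only calls it “the LR scheduler”) and is deliberately simplified: unlike torch.optim.lr_scheduler.ReduceLROnPlateau, it has no improvement threshold, cooldown, or minimum learning rate.

```python
from typing import List


def plateau_schedule(losses: List[float], lr: float, patience: int,
                     factor: float = 0.1) -> List[float]:
    """Simplified reduce-on-plateau: the learning rate in effect after each epoch."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:  # any strict decrease counts as an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # stalled for more than `patience` epochs
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


print(plateau_schedule([1.0, 1.0, 1.0, 1.0], lr=1.0, patience=2, factor=0.5))
```

With patience=2, the first cut happens on the third consecutive epoch without improvement, not the second.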

config_task()[source]

Configures the task based on the keyword arguments passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict”, as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]

forward(x)[source]

Forward pass of the model.

test_epoch_end(outputs)[source]

Logs epoch-level test metrics.

Parameters

outputs (Any) – list of items returned by test_step

test_step(batch, batch_idx)[source]

Test step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

training_epoch_end(outputs)[source]

Logs epoch-level training metrics.

Parameters

outputs (Any) – list of items returned by training_step

training_step(batch, batch_idx)[source]

Training step with an MSE loss.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

Returns

training loss

Return type

torch.Tensor
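The MSE loss named in the training step is the mean of the squared differences between predictions and targets. A minimal, dependency-free sketch of the formula (the function name is illustrative, not part of torchgeo):

```python
from typing import Sequence


def mse_loss(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error: the average of (prediction - target) ** 2 over the batch."""
    if not preds or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and the same length")
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
```

Because errors are squared, a single prediction that is off by 2 contributes as much as four predictions that are each off by 1.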

validation_epoch_end(outputs)[source]

Logs epoch-level validation metrics.

Parameters

outputs (Any) – list of items returned by validation_step

validation_step(batch, batch_idx)[source]

Validation step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

class torchgeo.trainers.SemanticSegmentationTask(**kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

LightningModule for semantic segmentation of images.

__init__(**kwargs)[source]

Initialize the LightningModule with a model and loss function.

Keyword Arguments
  • segmentation_model – Name of the segmentation model type to use

  • encoder_name – Name of the encoder model backbone to use

  • encoder_weights – None, or “imagenet” to use ImageNet-pretrained weights in the encoder model

  • in_channels – Number of channels in input image

  • num_classes – Number of semantic classes to predict

  • loss – Name of the loss function

  • ignore_zeros – Whether to ignore the “0” class value in the loss and metrics

Raises

ValueError – if the keyword arguments are invalid
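Ignoring the “0” class means pixels labelled 0 contribute to neither the numerator nor the denominator of a metric. The pixel-accuracy sketch below shows the effect; for the loss, PyTorch's torch.nn.CrossEntropyLoss offers the same behaviour through its ignore_index argument. How SemanticSegmentationTask wires ignore_zeros into its loss and metrics is not shown in this reference, so treat this as an illustration of the idea only.

```python
from typing import List, Optional


def pixel_accuracy(pred: List[List[int]], mask: List[List[int]],
                   ignore_index: Optional[int] = None) -> float:
    """Fraction of correctly classified pixels, optionally skipping one label value."""
    correct = total = 0
    for pred_row, mask_row in zip(pred, mask):
        for p, m in zip(pred_row, mask_row):
            if ignore_index is not None and m == ignore_index:
                continue  # excluded pixel: counts toward neither correct nor total
            total += 1
            correct += int(p == m)
    return correct / total if total else float("nan")


pred = [[1, 1], [2, 0]]
mask = [[0, 1], [2, 2]]
print(pixel_accuracy(pred, mask), pixel_accuracy(pred, mask, ignore_index=0))
```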

config_task()[source]

Configures the task based on the keyword arguments passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict”, as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]

forward(x)[source]

Forward pass of the model.

Parameters

x (torch.Tensor) – tensor of data to run through the model

Returns

output from the model

Return type

Any

test_epoch_end(outputs)[source]

Logs epoch-level test metrics.

Parameters

outputs (Any) – list of items returned by test_step

test_step(batch, batch_idx)[source]

Test step; identical to the validation step.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

training_epoch_end(outputs)[source]

Logs epoch-level training metrics.

Parameters

outputs (Any) – list of items returned by training_step

training_step(batch, batch_idx)[source]

Training step - reports average accuracy and average IoU.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch

Returns

training loss

Return type

torch.Tensor

validation_epoch_end(outputs)[source]

Logs epoch-level validation metrics.

Parameters

outputs (Any) – list of items returned by validation_step

validation_step(batch, batch_idx)[source]

Validation step - reports average accuracy and average IoU.

Logs the first 10 validation samples to TensorBoard as images with three subplots showing the image, mask, and prediction.

Parameters
  • batch (Dict[str, Any]) – Current batch

  • batch_idx (int) – Index of current batch
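The “average IoU” reported by the training and validation steps can be read as a class-averaged intersection-over-union. The sketch below assumes that reading (macro mean over classes) and skips classes absent from both prediction and target instead of scoring them 0 or 1; libraries differ on that choice, and the exact averaging used by these trainers is not stated here. The optional ignore_index mirrors the ignore_zeros idea.

```python
from typing import List, Optional


def mean_iou(pred: List[int], target: List[int], num_classes: int,
             ignore_index: Optional[int] = None) -> float:
    """Mean IoU over classes for flattened label maps (macro average)."""
    ious = []
    for c in range(num_classes):
        if c == ignore_index:
            continue
        inter = union = 0
        for p, t in zip(pred, target):
            if ignore_index is not None and t == ignore_index:
                continue  # ignored pixels are dropped from every class's counts
            inter += int(p == c and t == c)
            union += int(p == c or t == c)
        if union:  # class present in pred or target
            ious.append(inter / union)
    return sum(ious) / len(ious) if ious else float("nan")


print(mean_iou([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3))
```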

Documentation version: v0.2.0