torchgeo.trainers

TorchGeo trainers.

class torchgeo.trainers.BYOLTask(**kwargs)

Bases: LightningModule

Class for pre-training any PyTorch model using BYOL.

__init__(**kwargs)[source]

Initialize a LightningModule for pre-training a model with BYOL.

Keyword Arguments
  • in_channels – number of channels on the input imagery

  • encoder_name – either “resnet18” or “resnet50”

  • imagenet_pretraining – bool indicating whether to use ImageNet-pretrained weights

Raises

ValueError – if kwargs arguments are invalid
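
The argument contract above can be pictured with a small stand-alone sketch. The helper `check_byol_kwargs` is hypothetical and is not part of TorchGeo; it only mirrors what the documentation states (two accepted encoder names, ValueError on invalid arguments). The positive-integer check on in_channels is an assumption added for illustration.

```python
from typing import Any, Dict

# Encoder names listed in the keyword-argument documentation above.
SUPPORTED_ENCODERS = ("resnet18", "resnet50")


def check_byol_kwargs(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical validator mirroring the documented BYOLTask kwargs."""
    encoder_name = kwargs.get("encoder_name")
    if encoder_name not in SUPPORTED_ENCODERS:
        raise ValueError(
            f"encoder_name must be one of {SUPPORTED_ENCODERS}, got {encoder_name!r}"
        )
    in_channels = kwargs.get("in_channels")
    # Assumption: a channel count has to be a positive integer (bools excluded).
    if isinstance(in_channels, bool) or not isinstance(in_channels, int) or in_channels < 1:
        raise ValueError(f"in_channels must be a positive integer, got {in_channels!r}")
    # Return a copy so the caller's arguments are left untouched.
    return dict(kwargs)


config = check_byol_kwargs(
    in_channels=4, encoder_name="resnet18", imagenet_pretraining=True
)
```

With TorchGeo installed, a dictionary like `config` would be passed as keyword arguments to the task constructor, together with any other arguments your TorchGeo version requires.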

config_task()[source]

Configures the task based on kwargs parameters passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict” as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]
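
For orientation, the dictionary that PyTorch Lightning accepts from configure_optimizers has roughly the shape sketched below. The helper `make_lr_dict` is hypothetical, the strings stand in for real optimizer and scheduler objects, and the monitored metric name "val_loss" is an assumption rather than something this page states.

```python
from typing import Any, Dict


def make_lr_dict(optimizer: Any, scheduler: Any, monitor: str = "val_loss") -> Dict[str, Any]:
    """Assemble the optimizer/scheduler dictionary Lightning expects."""
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Metric the scheduler watches; the name is an assumption here.
            "monitor": monitor,
        },
    }


# Placeholders stand in for a torch optimizer and learning rate scheduler.
lr_dict = make_lr_dict("optimizer-placeholder", "scheduler-placeholder")
```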

forward(*args, **kwargs)[source]

Forward pass of the model.

Parameters

x – tensor of data to run through the model

Returns

output from the model

Return type

Any

test_step(*args, **kwargs)[source]

No-op.

training_step(*args, **kwargs)[source]

Compute and return the training loss.

Parameters

batch – the output of your DataLoader

Returns

training loss

Return type

Tensor

validation_step(*args, **kwargs)[source]

Compute validation loss.

Parameters

batch – the output of your DataLoader
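
As background for the training and validation steps above, the two core ingredients of BYOL (Bootstrap Your Own Latent) can be sketched in plain Python: the loss between an online prediction and a target projection, and the moving-average update of the target network. This follows the general BYOL recipe and is not code taken from TorchGeo; the function names and the decay value tau=0.99 are assumptions.

```python
import math
from typing import List


def byol_loss(prediction: List[float], target_projection: List[float]) -> float:
    """Return 2 - 2 * cosine similarity, the normalized MSE used by BYOL."""
    dot = sum(p * z for p, z in zip(prediction, target_projection))
    norm_p = math.sqrt(sum(p * p for p in prediction))
    norm_z = math.sqrt(sum(z * z for z in target_projection))
    return 2.0 - 2.0 * dot / (norm_p * norm_z)


def ema_update(target: List[float], online: List[float], tau: float = 0.99) -> List[float]:
    """Move target weights toward online weights: tau * target + (1 - tau) * online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


# Identical directions give the minimum loss, opposite directions the maximum.
aligned = byol_loss([1.0, 0.0], [1.0, 0.0])
opposed = byol_loss([1.0, 0.0], [-1.0, 0.0])
new_target = ema_update([1.0, 0.0], [0.0, 1.0], tau=0.5)
```

Only the online network receives gradients; the target network changes solely through the moving-average update.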

class torchgeo.trainers.ClassificationTask(**kwargs)

Bases: LightningModule

LightningModule for image classification.

__init__(**kwargs)[source]

Initialize the LightningModule with a model and loss function.

Keyword Arguments
  • classification_model – Name of the classification model to use

  • loss – Name of the loss function

  • weights – Either “random”, “imagenet_only”, “imagenet_and_random”, or “random_rgb”
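
A minimal sketch of assembling these keyword arguments follows. The helper `build_classification_kwargs` is hypothetical, the example values "resnet18" and "ce" are illustrative assumptions, and rejecting an unknown weights mode with ValueError is also an assumption, since this page does not document that behavior for ClassificationTask.

```python
from typing import Any, Dict

# The four weight-initialization modes listed above.
WEIGHT_MODES = ("random", "imagenet_only", "imagenet_and_random", "random_rgb")


def build_classification_kwargs(
    classification_model: str, loss: str, weights: str
) -> Dict[str, Any]:
    """Hypothetical helper that collects the documented keyword arguments."""
    if weights not in WEIGHT_MODES:
        # Assumption: fail early on a mode the documentation does not list.
        raise ValueError(f"weights must be one of {WEIGHT_MODES}, got {weights!r}")
    return {
        "classification_model": classification_model,
        "loss": loss,
        "weights": weights,
    }


# Example values are illustrative, not a statement of what TorchGeo accepts.
task_kwargs = build_classification_kwargs("resnet18", "ce", "imagenet_only")
```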

config_model()[source]

Configures the model based on kwargs parameters passed to the constructor.

config_task()[source]

Configures the task based on kwargs parameters passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict” as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]

forward(*args, **kwargs)[source]

Forward pass of the model.

Parameters

x – input image

Returns

prediction

Return type

Any

test_epoch_end(outputs)[source]

Logs epoch-level test metrics.

Parameters

outputs (Any) – list of items returned by test_step

test_step(*args, **kwargs)[source]

Compute test loss.

Parameters

batch – the output of your DataLoader

training_epoch_end(outputs)[source]

Logs epoch-level training metrics.

Parameters

outputs (Any) – list of items returned by training_step
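
The epoch-end hooks follow an accumulate, compute, reset pattern: values gathered during the steps are reduced to one number per epoch, logged, and cleared. The sketch below illustrates that pattern with a hypothetical `RunningMean` class and a plain dictionary as the log; it is not TorchGeo's implementation, and the metric name "train_loss" is an assumption.

```python
from typing import Dict


class RunningMean:
    """Hypothetical metric that averages values seen during one epoch."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        # Called once per step with that step's value.
        self.total += value
        self.count += 1

    def compute(self) -> float:
        # Returns 0.0 for an empty epoch to avoid dividing by zero.
        return self.total / self.count if self.count else 0.0

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0


def epoch_end(metric: RunningMean, log: Dict[str, float], name: str) -> None:
    """Log the epoch-level value, then clear state for the next epoch."""
    log[name] = metric.compute()
    metric.reset()


train_metric = RunningMean()
for step_loss in (1.0, 3.0):
    train_metric.update(step_loss)

logged: Dict[str, float] = {}
epoch_end(train_metric, logged, "train_loss")
```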

training_step(*args, **kwargs)[source]

Compute and return the training loss.

Parameters

batch – the output of your DataLoader

Returns

training loss

Return type

Tensor

validation_epoch_end(outputs)[source]

Logs epoch-level validation metrics.

Parameters

outputs (Any) – list of items returned by validation_step

validation_step(*args, **kwargs)[source]

Compute validation loss and log example predictions.

Parameters
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

class torchgeo.trainers.MultiLabelClassificationTask(**kwargs)

Bases: ClassificationTask

LightningModule for multi-label image classification.

__init__(**kwargs)[source]

Initialize the LightningModule with a model and loss function.

Keyword Arguments
  • classification_model – Name of the classification model to use

  • loss – Name of the loss function

  • weights – Either “random”, “imagenet_only”, “imagenet_and_random”, or “random_rgb”

config_task()[source]

Configures the task based on kwargs parameters passed to the constructor.

test_step(*args, **kwargs)[source]

Compute test loss.

Parameters

batch – the output of your DataLoader

training_step(*args, **kwargs)[source]

Compute and return the training loss.

Parameters

batch – the output of your DataLoader

Returns

training loss

Return type

Tensor

validation_step(*args, **kwargs)[source]

Compute validation loss and log example predictions.

Parameters
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch
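
The difference between ClassificationTask and MultiLabelClassificationTask comes down to how model outputs become predictions. The sketch below shows the usual convention, assuming single-label models pick the highest-scoring class and multi-label models apply an independent sigmoid and threshold per class. Both function names and the 0.5 threshold are assumptions, not TorchGeo's API.

```python
import math
from typing import List


def predict_single_label(logits: List[float]) -> int:
    """Pick exactly one class: the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


def predict_multi_label(logits: List[float], threshold: float = 0.5) -> List[int]:
    """Decide each class independently via sigmoid and a threshold."""
    probabilities = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [int(p >= threshold) for p in probabilities]


logits = [0.2, 2.0, -1.0]
single = predict_single_label(logits)
multi = predict_multi_label(logits)
```

Here the single-label prediction is one class index, while the multi-label prediction is a 0/1 vector in which several classes can be active at once.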

class torchgeo.trainers.RegressionTask(**kwargs)

Bases: LightningModule

LightningModule for training models on regression datasets.

__init__(**kwargs)[source]

Initialize a new LightningModule for training simple regression models.

Keyword Arguments
  • model – Name of the model to use

  • learning_rate – Initial learning rate to use in the optimizer

  • learning_rate_schedule_patience – Patience parameter for the LR scheduler
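
To make the patience parameter concrete, here is a simplified stand-alone sketch of a reduce-on-plateau schedule. It assumes the scheduler behaves like PyTorch's ReduceLROnPlateau, which this page does not state explicitly, and it omits that scheduler's threshold and cooldown options. The class name `PlateauSchedule` and the factor value are hypothetical.

```python
from typing import List


class PlateauSchedule:
    """Simplified reduce-on-plateau schedule (illustration only)."""

    def __init__(self, learning_rate: float, patience: int, factor: float = 0.1) -> None:
        self.learning_rate = learning_rate
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            # Improvement: remember it and restart the patience counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            # More than `patience` epochs without improvement: shrink the rate.
            self.learning_rate *= self.factor
            self.bad_epochs = 0
        return self.learning_rate


schedule = PlateauSchedule(learning_rate=1.0, patience=2, factor=0.5)
history: List[float] = [schedule.step(loss) for loss in (1.0, 1.0, 1.0, 1.0)]
```

With patience=2, the rate is only reduced on the third consecutive epoch that fails to improve on the best validation loss.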

config_task()[source]

Configures the task based on kwargs parameters.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict” as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]

forward(*args, **kwargs)[source]

Forward pass of the model.

Parameters

x – tensor of data to run through the model

Returns

output from the model

Return type

Any

test_epoch_end(outputs)[source]

Logs epoch-level test metrics.

Parameters

outputs (Any) – list of items returned by test_step

test_step(*args, **kwargs)[source]

Compute test loss.

Parameters

batch – the output of your DataLoader

training_epoch_end(outputs)[source]

Logs epoch-level training metrics.

Parameters

outputs (Any) – list of items returned by training_step

training_step(*args, **kwargs)[source]

Compute and return the training loss.

Parameters

batch – the output of your DataLoader

Returns

training loss

Return type

Tensor

validation_epoch_end(outputs)[source]

Logs epoch-level validation metrics.

Parameters

outputs (Any) – list of items returned by validation_step

validation_step(*args, **kwargs)[source]

Compute validation loss and log example predictions.

Parameters
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch

class torchgeo.trainers.SemanticSegmentationTask(**kwargs)

Bases: LightningModule

LightningModule for semantic segmentation of images.

__init__(**kwargs)[source]

Initialize the LightningModule with a model and loss function.

Keyword Arguments
  • segmentation_model – Name of the segmentation model type to use

  • encoder_name – Name of the encoder model backbone to use

  • encoder_weights – None, or “imagenet” to use ImageNet-pretrained weights in the encoder model

  • in_channels – Number of channels in input image

  • num_classes – Number of semantic classes to predict

  • loss – Name of the loss function

  • ignore_index – Optional integer class index to ignore in the loss and metrics

Raises

ValueError – if kwargs arguments are invalid

Changed in version 0.3: The ignore_zeros parameter was renamed to ignore_index.
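
Configurations written before version 0.3 may still carry the old parameter name. The hypothetical helper below sketches one way to translate them. It assumes ignore_zeros was a boolean meaning "ignore class index 0", so True maps to ignore_index=0 and False to ignore_index=None; verify that mapping against the release you are migrating from.

```python
from typing import Any, Dict


def migrate_segmentation_kwargs(old_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical migration from ignore_zeros to ignore_index."""
    new_kwargs = dict(old_kwargs)  # copy, so the caller's dict is not mutated
    if "ignore_zeros" in new_kwargs:
        ignore_zeros = new_kwargs.pop("ignore_zeros")
        # Assumption: True meant "skip class 0"; False meant "skip nothing".
        # An explicit ignore_index that is already present takes precedence.
        new_kwargs.setdefault("ignore_index", 0 if ignore_zeros else None)
    return new_kwargs


old_config = {"loss": "ce", "ignore_zeros": True}
new_config = migrate_segmentation_kwargs(old_config)
```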

config_task()[source]

Configures the task based on kwargs parameters passed to the constructor.

configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler.

Returns

an “lr dict” as described in the PyTorch Lightning documentation – https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

Return type

Dict[str, Any]

forward(*args, **kwargs)[source]

Forward pass of the model.

Parameters

x – tensor of data to run through the model

Returns

output from the model

Return type

Any

test_epoch_end(outputs)[source]

Logs epoch-level test metrics.

Parameters

outputs (Any) – list of items returned by test_step

test_step(*args, **kwargs)[source]

Compute test loss.

Parameters

batch – the output of your DataLoader

training_epoch_end(outputs)[source]

Logs epoch-level training metrics.

Parameters

outputs (Any) – list of items returned by training_step

training_step(*args, **kwargs)[source]

Compute and return the training loss.

Parameters

batch – the output of your DataLoader

Returns

training loss

Return type

Tensor

validation_epoch_end(outputs)[source]

Logs epoch-level validation metrics.

Parameters

outputs (Any) – list of items returned by validation_step

validation_step(*args, **kwargs)[source]

Compute validation loss and log example predictions.

Parameters
  • batch – the output of your DataLoader

  • batch_idx – the index of this batch
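
The effect of the ignore_index argument of SemanticSegmentationTask on metrics can be illustrated with a small stand-alone sketch. It computes pixel accuracy over flattened label lists and skips every pixel whose target equals ignore_index. The function `pixel_accuracy` is hypothetical and only demonstrates the masking idea, not TorchGeo's metric code.

```python
from typing import List, Optional


def pixel_accuracy(
    predictions: List[int], targets: List[int], ignore_index: Optional[int] = None
) -> float:
    """Fraction of correctly labeled pixels, skipping ignored targets."""
    kept = [
        (p, t)
        for p, t in zip(predictions, targets)
        if ignore_index is None or t != ignore_index
    ]
    if not kept:
        # Every pixel was ignored; report 0.0 rather than dividing by zero.
        return 0.0
    return sum(p == t for p, t in kept) / len(kept)


predictions = [1, 1, 2, 0]
targets = [1, 2, 2, 0]
accuracy_all = pixel_accuracy(predictions, targets)
accuracy_masked = pixel_accuracy(predictions, targets, ignore_index=0)
```

In this example the masked score is lower than the unmasked one, because the correctly predicted pixel of class 0 no longer counts.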
