# Source code for torchgeo.trainers.base
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Base classes for all :mod:`torchgeo` trainers."""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any
import lightning
from lightning.pytorch import LightningModule
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
class BaseTask(LightningModule, ABC):
    """Abstract base class for all TorchGeo trainers.

    .. versionadded:: 0.5
    """

    #: Model to train.
    model: Any

    #: Performance metric to monitor in learning rate scheduler and callbacks.
    monitor = 'val_loss'

    #: Whether the goal is to minimize or maximize the performance metric to monitor.
    mode = 'min'
[docs] def __init__(self, ignore: Sequence[str] | str | None = None) -> None:
"""Initialize a new BaseTask instance.
Args:
ignore: Arguments to skip when saving hyperparameters.
"""
super().__init__()
self.save_hyperparameters(ignore=ignore)
self.configure_models()
self.configure_losses()
self.configure_metrics()
[docs] def configure_optimizers(
self,
) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig':
"""Initialize the optimizer and learning rate scheduler.
Returns:
Optimizer and learning rate scheduler.
"""
optimizer = AdamW(self.parameters(), lr=self.hparams['lr'])
scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams['patience'])
return {
'optimizer': optimizer,
'lr_scheduler': {'scheduler': scheduler, 'monitor': self.monitor},
}
[docs] def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass of the model.
Args:
args: Arguments to pass to model.
kwargs: Keyword arguments to pass to model.
Returns:
Output of the model.
"""
return self.model(*args, **kwargs)