
Source code for torchgeo.trainers.base

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Base classes for all :mod:`torchgeo` trainers."""

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional, Union

import lightning
from lightning.pytorch import LightningModule
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau


class BaseTask(LightningModule, ABC):
    """Abstract base class for all TorchGeo trainers.

    .. versionadded:: 0.5
    """

    #: Model to train.
    model: Any

    #: Performance metric to monitor in learning rate scheduler and callbacks.
    monitor = "val_loss"

    #: Whether the goal is to minimize or maximize the performance metric to monitor.
    mode = "min"

    def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
        """Initialize a new BaseTask instance.

        Args:
            ignore: Arguments to skip when saving hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=ignore)
        self.configure_losses()
        self.configure_metrics()
        self.configure_models()

    def configure_losses(self) -> None:
        """Initialize the loss criterion."""

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

    @abstractmethod
    def configure_models(self) -> None:
        """Initialize the model."""

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams["lr"])
        scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"])
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
        }

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            args: Arguments to pass to model.
            kwargs: Keyword arguments to pass to model.

        Returns:
            Output of the model.
        """
        return self.model(*args, **kwargs)
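The sketch below (not part of the library source) illustrates how a concrete trainer might subclass BaseTask: configure_models is the only abstract hook, and the hyperparameter names lr and patience are assumed because configure_optimizers reads them from self.hparams. The class name SimpleRegressionTask, the torchvision ResNet-18 backbone, and the batch keys "image" and "label" are hypothetical choices for illustration.

# Hypothetical subclass sketch; relies on the imports above plus torch/torchvision.
import torch.nn as nn
from torchvision.models import resnet18


class SimpleRegressionTask(BaseTask):
    """Toy trainer that predicts a single scalar per image (illustrative only)."""

    def __init__(self, lr: float = 1e-3, patience: int = 10) -> None:
        # ``lr`` and ``patience`` are captured by ``save_hyperparameters`` in
        # ``BaseTask.__init__``, which then calls the ``configure_*`` hooks below.
        super().__init__()

    def configure_models(self) -> None:
        # Required: BaseTask declares this method abstract.
        backbone = resnet18(weights=None)
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)
        self.model = backbone

    def configure_losses(self) -> None:
        # Optional hook: define the loss criterion used in training steps.
        self.criterion = nn.MSELoss()

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        # Assumes samples are dicts with "image" and "label" keys.
        x, y = batch["image"], batch["label"]
        y_hat = self(x).squeeze(dim=-1)
        loss = self.criterion(y_hat, y.float())
        self.log("train_loss", loss)
        return loss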
