
Source code for torchgeo.trainers.regression

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Regression tasks."""

from typing import Any, Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchvision import models

from ..datasets.utils import unbind_samples

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"


class RegressionTask(pl.LightningModule):
    """LightningModule for training models on regression datasets."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters."""
        if self.hparams["model"] == "resnet18":
            self.model = models.resnet18(pretrained=self.hparams["pretrained"])
            in_features = self.model.fc.in_features
            self.model.fc = nn.Linear(  # type: ignore[attr-defined]
                in_features, out_features=1
            )
        else:
            raise ValueError(f"Model type '{self.hparams['model']}' is not valid.")
    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new LightningModule for training simple regression models.

        Keyword Args:
            model: Name of the model to use
            learning_rate: Initial learning rate to use in the optimizer
            learning_rate_schedule_patience: Patience parameter for the LR scheduler
        """
        super().__init__()
        self.save_hyperparameters()  # creates `self.hparams` from kwargs
        self.config_task()

        self.train_metrics = MetricCollection(
            {"RMSE": MeanSquaredError(squared=False), "MAE": MeanAbsoluteError()},
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")
    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
        """Forward pass of the model."""
        return self.model(x)
    def training_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Tensor:
        """Training step with an MSE loss.

        Args:
            batch: Current batch
            batch_idx: Index of current batch

        Returns:
            training loss
        """
        x = batch["image"]
        y = batch["label"].view(-1, 1)
        y_hat = self.forward(x)

        loss = F.mse_loss(y_hat, y)
        self.log("train_loss", loss)  # logging to TensorBoard
        self.train_metrics(y_hat, y)

        return loss
    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()
    def validation_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Validation step.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["label"].view(-1, 1)
        y_hat = self.forward(x)

        loss = F.mse_loss(y_hat, y)
        self.log("val_loss", loss)
        self.val_metrics(y_hat, y)

        if batch_idx < 10:
            try:
                datamodule = self.trainer.datamodule  # type: ignore[attr-defined]
                batch["prediction"] = y_hat
                for key in ["image", "label", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.plot(sample)
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
            except AttributeError:
                pass
    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()
    def test_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Test step.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["label"].view(-1, 1)
        y_hat = self.forward(x)

        loss = F.mse_loss(y_hat, y)
        self.log("test_loss", loss)
        self.test_metrics(y_hat, y)
    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch-level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()
    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.hparams["learning_rate"]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
            },
        }
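
For context, a minimal usage sketch of the class above, not part of the module source. It assumes a LightningDataModule whose batches are dicts with "image" and "label" keys, as the training/validation/test steps expect; the datamodule name `MyRegressionDataModule` and the hyperparameter values are illustrative, not part of torchgeo.

# Minimal usage sketch.
# Assumptions: `MyRegressionDataModule` is a hypothetical, user-defined
# LightningDataModule yielding {"image": ..., "label": ...} batches;
# hyperparameter values are illustrative only.
import pytorch_lightning as pl

from torchgeo.trainers import RegressionTask

task = RegressionTask(
    model="resnet18",                   # the only backbone accepted by config_task
    pretrained=True,                    # load ImageNet weights for the backbone
    learning_rate=1e-3,                 # initial AdamW learning rate
    learning_rate_schedule_patience=5,  # ReduceLROnPlateau patience (epochs)
)

datamodule = MyRegressionDataModule(batch_size=64, num_workers=4)  # hypothetical

trainer = pl.Trainer(max_epochs=20)
trainer.fit(task, datamodule=datamodule)   # optimizes the MSE loss from training_step
trainer.test(task, datamodule=datamodule)  # logs test_RMSE and test_MAE

Because `save_hyperparameters()` stores all keyword arguments on `self.hparams`, the same keys ("model", "pretrained", "learning_rate", "learning_rate_schedule_patience") are what `config_task` and `configure_optimizers` read back at setup time.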
