
Source code for torchgeo.trainers.segmentation

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Segmentation tasks."""

import warnings
from typing import Any, Dict, cast

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex

from ..datasets.utils import unbind_samples
from ..models import FCN


class SemanticSegmentationTask(pl.LightningModule):
    """LightningModule for semantic segmentation of images.

    Supports `Segmentation Models Pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_
    as an architecture choice in combination with any of these
    `TIMM backbones <https://smp.readthedocs.io/en/latest/encoders_timm.html>`_.
    """
    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        if self.hyperparams["model"] == "unet":
            self.model = smp.Unet(
                encoder_name=self.hyperparams["backbone"],
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.hyperparams["backbone"],
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "fcn":
            self.model = FCN(
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
                num_filters=self.hyperparams["num_filters"],
            )
        else:
            raise ValueError(
                f"Model type '{self.hyperparams['model']}' is not valid. "
                "Currently, only 'unet', 'deeplabv3+' and 'fcn' are supported."
            )

        if self.hyperparams["loss"] == "ce":
            ignore_value = -1000 if self.ignore_index is None else self.ignore_index
            self.loss = nn.CrossEntropyLoss(ignore_index=ignore_value)
        elif self.hyperparams["loss"] == "jaccard":
            # JaccardLoss expects a list of class indices to include, not a
            # class count; passing the raw count would select a single
            # (out-of-range) class, so include all classes explicitly.
            self.loss = smp.losses.JaccardLoss(
                mode="multiclass",
                classes=list(range(self.hyperparams["num_classes"])),
            )
        elif self.hyperparams["loss"] == "focal":
            self.loss = smp.losses.FocalLoss(
                "multiclass", ignore_index=self.ignore_index, normalized=True
            )
        else:
            raise ValueError(
                f"Loss type '{self.hyperparams['loss']}' is not valid. "
                "Currently, only 'ce', 'jaccard' and 'focal' are supported."
            )
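    # A minimal sketch (an illustrative addition, not part of the original
    # source) of what the 'ce' branch above configures. With a 3-class
    # problem and ignore_index=0, pixels labeled 0 contribute nothing to
    # the loss:
    #
    #   loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    #   logits = torch.randn(2, 3, 64, 64)         # (batch, classes, H, W)
    #   target = torch.randint(0, 3, (2, 64, 64))  # (batch, H, W)
    #   value = loss_fn(logits, target)            # class-0 pixels are skipped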
    def __init__(self, **kwargs: Any) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            model: Name of the segmentation model type to use
            backbone: Name of the timm backbone to use
            weights: None or "imagenet" to use imagenet pretrained weights in
                the backbone
            in_channels: Number of channels in input image
            num_classes: Number of semantic classes to predict
            loss: Name of the loss function, currently supports
                'ce', 'jaccard' or 'focal' loss
            ignore_index: Optional integer class index to ignore in the loss
                and metrics
            learning_rate: Learning rate for optimizer
            learning_rate_schedule_patience: Patience for learning rate scheduler

        Raises:
            ValueError: if kwargs arguments are invalid

        .. versionchanged:: 0.3
           The *ignore_zeros* parameter was renamed to *ignore_index*.

        .. versionchanged:: 0.4
           The *segmentation_model* parameter was renamed to *model*,
           *encoder_name* renamed to *backbone*, and *encoder_weights*
           to *weights*.
        """
        super().__init__()

        # Creates `self.hparams` from kwargs
        self.save_hyperparameters()  # type: ignore[operator]
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        if not isinstance(kwargs["ignore_index"], (int, type(None))):
            raise ValueError("ignore_index must be an int or None")
        if (kwargs["ignore_index"] is not None) and (kwargs["loss"] == "jaccard"):
            warnings.warn(
                "ignore_index has no effect on training when loss='jaccard'",
                UserWarning,
            )
        self.ignore_index = kwargs["ignore_index"]
        self.config_task()

        self.train_metrics = MetricCollection(
            [
                MulticlassAccuracy(
                    num_classes=self.hyperparams["num_classes"],
                    ignore_index=self.ignore_index,
                    # MulticlassAccuracy takes ``multidim_average``; the legacy
                    # ``mdmc_average`` argument belongs to the old Accuracy API
                    multidim_average="global",
                ),
                MulticlassJaccardIndex(
                    num_classes=self.hyperparams["num_classes"],
                    ignore_index=self.ignore_index,
                ),
            ],
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")
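    # A minimal construction sketch (an illustrative addition, not part of
    # the original source). The keyword names are those documented above;
    # the values are arbitrary examples:
    #
    #   task = SemanticSegmentationTask(
    #       model="unet",
    #       backbone="resnet50",
    #       weights="imagenet",
    #       in_channels=3,
    #       num_classes=5,
    #       loss="ce",
    #       ignore_index=None,
    #       learning_rate=1e-3,
    #       learning_rate_schedule_patience=6,
    #   )
    #
    # Note that ``ignore_index`` and ``loss`` must always be passed, since
    # ``__init__`` indexes ``kwargs`` for both; ``num_filters`` is only
    # consulted when ``model="fcn"``.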
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return self.model(*args, **kwargs)
    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader

        Returns:
            training loss
        """
        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the train step logs every `log_every_n_steps` steps where
        # `log_every_n_steps` is a parameter to the `Trainer` object
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat_hard, y)

        return cast(Tensor, loss)
    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()
    def validation_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute validation loss and log example predictions.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
        """
        batch = args[0]
        batch_idx = args[1]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat_hard, y)

        if (
            batch_idx < 10
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and hasattr(self.logger, "experiment")
        ):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction"] = y_hat_hard
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.plot(sample)
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
                plt.close()
            except ValueError:
                pass
    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()
    def test_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute test loss.

        Args:
            batch: the output of your DataLoader
        """
        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the test and validation steps only log per *epoch*
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(y_hat_hard, y)
    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()
    def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the predictions.

        By default, this will loop over images in a dataloader and aggregate
        predictions into a list. This may not be desirable if you have many
        images or large images, which could cause out-of-memory errors. In
        that case it's recommended to override this with a custom predict_step.

        Args:
            batch: the output of your DataLoader

        Returns:
            predicted softmax probabilities
        """
        batch = args[0]
        x = batch["image"]
        y_hat: Tensor = self(x).softmax(dim=1)
        return y_hat
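    # A hedged sketch (an illustrative addition, not part of the original
    # source) of the override suggested in the docstring above: stream
    # predictions to disk rather than returning them, so that large rasters
    # are never accumulated in memory. ``StreamingSegmentationTask`` and
    # ``save_prediction`` are hypothetical names:
    #
    #   class StreamingSegmentationTask(SemanticSegmentationTask):
    #       def predict_step(self, *args: Any, **kwargs: Any) -> None:
    #           batch, batch_idx = args[0], args[1]
    #           y_hat = self(batch["image"]).softmax(dim=1)
    #           save_prediction(y_hat.cpu(), batch_idx)  # hypothetical helper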
    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.hyperparams["learning_rate"]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hyperparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
            },
        }
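
A minimal end-to-end training sketch (an illustrative addition, not part of the original module). It assumes a LightningDataModule whose batches contain "image" and "mask" keys; ``MyDataModule`` is a hypothetical placeholder and the hyperparameter values are arbitrary:

if __name__ == "__main__":
    task = SemanticSegmentationTask(
        model="deeplabv3+",
        backbone="resnet18",
        weights="imagenet",
        in_channels=3,
        num_classes=2,
        loss="focal",
        ignore_index=None,
        learning_rate=1e-3,
        learning_rate_schedule_patience=6,
    )
    datamodule = MyDataModule()  # hypothetical: yields {"image": ..., "mask": ...}
    # ReduceLROnPlateau above monitors "val_loss", which validation_step logs
    trainer = pl.Trainer(max_epochs=10, log_every_n_steps=50)
    trainer.fit(task, datamodule=datamodule)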
