Source code for torchgeo.trainers.segmentation

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Segmentation tasks."""

from typing import Any, Dict, cast

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, IoU, MetricCollection

from ..datasets.utils import unbind_samples
from ..models import FCN

# Sphinx cross-references to DataLoader break because its ``__module__``
# points at the private submodule ``torch.utils.data.dataloader``; resetting
# it to the public path lets documentation links resolve. See:
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class SemanticSegmentationTask(LightningModule):
    """LightningModule for semantic segmentation of images."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        if self.hparams["segmentation_model"] == "unet":
            self.model = smp.Unet(
                encoder_name=self.hparams["encoder_name"],
                encoder_weights=self.hparams["encoder_weights"],
                in_channels=self.hparams["in_channels"],
                classes=self.hparams["num_classes"],
            )
        elif self.hparams["segmentation_model"] == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.hparams["encoder_name"],
                encoder_weights=self.hparams["encoder_weights"],
                in_channels=self.hparams["in_channels"],
                classes=self.hparams["num_classes"],
            )
        elif self.hparams["segmentation_model"] == "fcn":
            self.model = FCN(
                in_channels=self.hparams["in_channels"],
                classes=self.hparams["num_classes"],
                num_filters=self.hparams["num_filters"],
            )
        else:
            raise ValueError(
                f"Model type '{self.hparams['segmentation_model']}' is not valid."
            )

        if self.hparams["loss"] == "ce":
            self.loss = nn.CrossEntropyLoss(  # type: ignore[attr-defined]
                ignore_index=-1000 if self.ignore_zeros is None else 0
            )
        elif self.hparams["loss"] == "jaccard":
            self.loss = smp.losses.JaccardLoss(
                mode="multiclass", classes=self.hparams["num_classes"]
            )
        elif self.hparams["loss"] == "focal":
            self.loss = smp.losses.FocalLoss(
                "multiclass", ignore_index=self.ignore_zeros, normalized=True
            )
        else:
            raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
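
For reference, the "unet" branch above reduces to a direct segmentation_models_pytorch call. A minimal sketch, with illustrative (not default) hyperparameter values:

# Example (not part of the module): what config_task builds for
# segmentation_model="unet". All values are illustrative.
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet18",     # any encoder supported by smp
    encoder_weights="imagenet",  # or None to train from scratch
    in_channels=4,               # e.g. RGB + NIR imagery
    classes=7,                   # number of semantic classes
)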
    def __init__(self, **kwargs: Any) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            segmentation_model: Name of the segmentation model type to use
            encoder_name: Name of the encoder model backbone to use
            encoder_weights: None or "imagenet" to use imagenet pretrained
                weights in the encoder model
            in_channels: Number of channels in input image
            num_classes: Number of semantic classes to predict
            loss: Name of the loss function
            ignore_zeros: Whether to ignore the "0" class value in the loss and
                metrics

        Raises:
            ValueError: if kwargs arguments are invalid
        """
        super().__init__()
        self.save_hyperparameters()  # creates `self.hparams` from kwargs

        # Map the ``ignore_zeros`` flag to an ``ignore_index``: ignore class 0
        # when the flag is set, otherwise ignore nothing.
        self.ignore_zeros = 0 if kwargs["ignore_zeros"] else None

        self.config_task()

        self.train_metrics = MetricCollection(
            [
                Accuracy(
                    num_classes=self.hparams["num_classes"],
                    ignore_index=self.ignore_zeros,
                ),
                IoU(
                    num_classes=self.hparams["num_classes"],
                    ignore_index=self.ignore_zeros,
                ),
            ],
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")
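
A minimal sketch of constructing the task. The learning rate keys are consumed later by configure_optimizers; all values here are illustrative:

# Example (not part of the module): constructing the task with
# illustrative hyperparameters.
task = SemanticSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=4,
    num_classes=7,
    loss="ce",
    ignore_zeros=False,
    learning_rate=1e-3,
    learning_rate_schedule_patience=5,
)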
    def forward(self, x: Tensor) -> Any:  # type: ignore[override]
        """Forward pass of the model.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return self.model(x)
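
The model maps a batch of images to per-pixel class logits. Continuing the construction sketch above, with assumed shapes:

# Example (not part of the module): shapes assume the task sketched above
# (in_channels=4, num_classes=7).
x = torch.rand(2, 4, 256, 256)  # (batch, channels, height, width)
y_hat = task(x)                 # calls forward under the hood
assert y_hat.shape == (2, 7, 256, 256)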
    def training_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> Tensor:
        """Training step - reports average accuracy and average IoU.

        Args:
            batch: Current batch
            batch_idx: Index of current batch

        Returns:
            training loss
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the train step logs every `log_every_n_steps` steps where
        # `log_every_n_steps` is a parameter to the `Trainer` object
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat_hard, y)

        return cast(Tensor, loss)
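
Batches are dictionaries with "image" and "mask" keys, the convention used by torchgeo datasets. A sketch of the tensors the step expects, reusing the task from the earlier example:

# Example (not part of the module): a synthetic batch in the expected format.
batch = {
    "image": torch.rand(2, 4, 256, 256),         # float imagery
    "mask": torch.randint(0, 7, (2, 256, 256)),  # integer class labels
}
y_hat = task(batch["image"])
y_hat_hard = y_hat.argmax(dim=1)                 # (2, 256, 256) class map
loss = task.loss(y_hat, batch["mask"])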
    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()
    def validation_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Validation step - reports average accuracy and average IoU.

        Logs the first 10 validation samples to tensorboard as images with 3
        subplots showing the image, mask, and predictions.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat_hard, y)

        if batch_idx < 10:
            try:
                datamodule = self.trainer.datamodule  # type: ignore[attr-defined]
                batch["prediction"] = y_hat_hard
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.plot(sample)
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
            except AttributeError:
                pass
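
The image logging above relies on unbind_samples, which splits a dictionary of batched tensors into a list of per-sample dictionaries. A minimal sketch of that behavior:

# Example (not part of the module): unbind_samples turns a dict of batched
# tensors into a list of per-sample dicts.
from torchgeo.datasets.utils import unbind_samples

batch = {
    "image": torch.rand(2, 4, 64, 64),
    "mask": torch.randint(0, 7, (2, 64, 64)),
}
samples = unbind_samples(batch)
assert len(samples) == 2
assert samples[0]["image"].shape == (4, 64, 64)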
    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()
    def test_step(  # type: ignore[override]
        self, batch: Dict[str, Any], batch_idx: int
    ) -> None:
        """Test step identical to the validation step.

        Args:
            batch: Current batch
            batch_idx: Index of current batch
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the test and validation steps only log per *epoch*
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(y_hat_hard, y)
    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()
    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.hparams["learning_rate"]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
            },
        }
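
Because the scheduler entry monitors "val_loss", ReduceLROnPlateau only steps when validation runs. A sketch of a typical training loop; the datamodule here is a placeholder for any LightningDataModule that yields {"image", "mask"} batches:

# Example (not part of the module): training the task. `datamodule` is a
# placeholder; substitute any LightningDataModule producing image/mask
# batches.
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=10)
trainer.fit(task, datamodule=datamodule)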
