
Source code for torchgeo.trainers.segmentation

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Trainers for semantic segmentation."""

import os
from typing import Any

import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch.nn as nn
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum

from ..datasets import RGBBandsMissingError, unbind_samples
from ..models import FCN, get_weight
from . import utils
from .base import BaseTask


class SemanticSegmentationTask(BaseTask):
    """Semantic Segmentation."""
    def __init__(
        self,
        model: str = "unet",
        backbone: str = "resnet50",
        weights: WeightsEnum | str | bool | None = None,
        in_channels: int = 3,
        num_classes: int = 1000,
        num_filters: int = 3,
        loss: str = "ce",
        class_weights: Tensor | None = None,
        ignore_index: int | None = None,
        lr: float = 1e-3,
        patience: int = 10,
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
    ) -> None:
        """Initialize a new SemanticSegmentationTask instance.

        Args:
            model: Name of the
                `smp <https://smp.readthedocs.io/en/latest/models.html>`__ model to use.
            backbone: Name of the `timm
                <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp
                <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False or
                None for random weights, or the path to a saved model state dict. FCN
                model does not support pretrained weights. Pretrained ViT weight enums
                are not supported yet.
            in_channels: Number of input channels to model.
            num_classes: Number of prediction classes.
            num_filters: Number of filters. Only applicable when model='fcn'.
            loss: Name of the loss function, currently supports
                'ce', 'jaccard' or 'focal' loss.
            class_weights: Optional rescaling weight given to each
                class and used with 'ce' loss.
            ignore_index: Optional integer class index to ignore in the loss and
                metrics.
            lr: Learning rate for optimizer.
            patience: Patience for learning rate scheduler.
            freeze_backbone: Freeze the backbone network to fine-tune the
                decoder and segmentation head.
            freeze_decoder: Freeze the decoder network to linear probe
                the segmentation head.

        .. versionchanged:: 0.3
           *ignore_zeros* was renamed to *ignore_index*.

        .. versionchanged:: 0.4
           *segmentation_model*, *encoder_name*, and *encoder_weights*
           were renamed to *model*, *backbone*, and *weights*.

        .. versionadded:: 0.5
           The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.

        .. versionchanged:: 0.5
           The *weights* parameter now supports WeightEnums and checkpoint paths.
           *learning_rate* and *learning_rate_schedule_patience* were renamed to
           *lr* and *patience*.

        .. versionchanged:: 0.6
           The *ignore_index* parameter now works for jaccard loss.
        """
        self.weights = weights
        super().__init__(ignore="weights")
    def configure_models(self) -> None:
        """Initialize the model.

        Raises:
            ValueError: If *model* is invalid.
        """
        model: str = self.hparams["model"]
        backbone: str = self.hparams["backbone"]
        weights = self.weights
        in_channels: int = self.hparams["in_channels"]
        num_classes: int = self.hparams["num_classes"]
        num_filters: int = self.hparams["num_filters"]

        if model == "unet":
            self.model = smp.Unet(
                encoder_name=backbone,
                encoder_weights="imagenet" if weights is True else None,
                in_channels=in_channels,
                classes=num_classes,
            )
        elif model == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=backbone,
                encoder_weights="imagenet" if weights is True else None,
                in_channels=in_channels,
                classes=num_classes,
            )
        elif model == "fcn":
            self.model = FCN(
                in_channels=in_channels, classes=num_classes, num_filters=num_filters
            )
        else:
            raise ValueError(
                f"Model type '{model}' is not valid. "
                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
            )

        if model != "fcn":
            if weights and weights is not True:
                if isinstance(weights, WeightsEnum):
                    state_dict = weights.get_state_dict(progress=True)
                elif os.path.exists(weights):
                    _, state_dict = utils.extract_backbone(weights)
                else:
                    state_dict = get_weight(weights).get_state_dict(progress=True)
                self.model.encoder.load_state_dict(state_dict)

        # Freeze backbone
        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.encoder.parameters():
                param.requires_grad = False

        # Freeze decoder
        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.decoder.parameters():
                param.requires_grad = False
    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]
        if loss == "ce":
            ignore_value = -1000 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(
                ignore_index=ignore_value, weight=self.hparams["class_weights"]
            )
        elif loss == "jaccard":
            # JaccardLoss requires a list of classes to use instead of a class
            # index to ignore.
            classes = [
                i for i in range(self.hparams["num_classes"]) if i != ignore_index
            ]

            self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=classes)
        elif loss == "focal":
            self.criterion = smp.losses.FocalLoss(
                "multiclass", ignore_index=ignore_index, normalized=True
            )
        else:
            raise ValueError(
                f"Loss type '{loss}' is not valid. "
                "Currently, supports 'ce', 'jaccard' or 'focal' loss."
            )
    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
          (OA) using 'micro' averaging. The number of true positives divided by the
          dataset size. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'micro' averaging. Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not
             reflect minority class accuracy.
           * 'Macro' averaging, not used here, gives equal weight to each class,
             useful for balanced performance assessment across imbalanced classes.
        """
        num_classes: int = self.hparams["num_classes"]
        ignore_index: int | None = self.hparams["ignore_index"]
        metrics = MetricCollection(
            [
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
                MulticlassJaccardIndex(
                    num_classes=num_classes, ignore_index=ignore_index, average="micro"
                ),
            ]
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")
    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
        """
        x = batch["image"]
        y = batch["mask"]
        batch_size = x.shape[0]
        y_hat = self(x)
        loss: Tensor = self.criterion(y_hat, y)
        self.log("train_loss", loss, batch_size=batch_size)
        self.train_metrics(y_hat, y)
        self.log_dict(self.train_metrics, batch_size=batch_size)
        return loss
    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        batch_size = x.shape[0]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss, batch_size=batch_size)
        self.val_metrics(y_hat, y)
        self.log_dict(self.val_metrics, batch_size=batch_size)

        if (
            batch_idx < 10
            and hasattr(self.trainer, "datamodule")
            and hasattr(self.trainer.datamodule, "plot")
            and self.logger
            and hasattr(self.logger, "experiment")
            and hasattr(self.logger.experiment, "add_figure")
        ):
            datamodule = self.trainer.datamodule
            batch["prediction"] = y_hat.argmax(dim=1)
            for key in ["image", "mask", "prediction"]:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]

            fig: Figure | None = None
            try:
                fig = datamodule.plot(sample)
            except RGBBandsMissingError:
                pass

            if fig:
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
                plt.close()
    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        batch_size = x.shape[0]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("test_loss", loss, batch_size=batch_size)
        self.test_metrics(y_hat, y)
        self.log_dict(self.test_metrics, batch_size=batch_size)
    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        y_hat: Tensor = self(x).softmax(dim=1)
        return y_hat
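
A minimal training sketch (not part of the module above). It assumes any LightningDataModule whose batches provide "image" and "mask" keys; LandCoverAIDataModule is used purely as an illustration, and the root path, class count, and hyperparameter values are assumptions rather than recommended settings. It also assumes Lightning's ``lightning.pytorch`` namespace is available.

import lightning.pytorch as pl

from torchgeo.datamodules import LandCoverAIDataModule
from torchgeo.trainers import SemanticSegmentationTask

# Any segmentation datamodule with "image"/"mask" batch keys fits this contract.
datamodule = LandCoverAIDataModule(root="data/landcoverai", batch_size=8, num_workers=4)

task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=True,      # True -> ImageNet-pretrained encoder weights
    in_channels=3,
    num_classes=5,     # background + 4 LandCover.ai classes (assumed for illustration)
    loss="ce",
    lr=1e-3,
    patience=10,
)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=task, datamodule=datamodule)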
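
The *weights* argument documented in ``__init__`` accepts several forms. The sketch below shows each one; the ``ResNet50_Weights.SENTINEL2_ALL_MOCO`` enum (13-band Sentinel-2 weights) and the checkpoint path are illustrative assumptions, not requirements.

from torchgeo.models import ResNet50_Weights
from torchgeo.trainers import SemanticSegmentationTask

# Weight enum (these Sentinel-2 weights assume 13 input channels).
task = SemanticSegmentationTask(
    weights=ResNet50_Weights.SENTINEL2_ALL_MOCO, in_channels=13, num_classes=5
)

# String representation of the same enum.
task = SemanticSegmentationTask(
    weights="ResNet50_Weights.SENTINEL2_ALL_MOCO", in_channels=13, num_classes=5
)

# True -> ImageNet encoder weights; False or None -> random initialization.
task = SemanticSegmentationTask(weights=True, in_channels=3, num_classes=5)

# Path to a saved model state dict (hypothetical path).
task = SemanticSegmentationTask(weights="output/backbone.ckpt", in_channels=3, num_classes=5)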
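
``predict_step`` returns per-pixel class probabilities (softmax over the class dimension). A short sketch of turning those probabilities into hard segmentation masks, reusing the hypothetical ``task``, ``datamodule``, and ``trainer`` objects from the first sketch:

# trainer.predict returns one output per batch, each of shape
# (batch_size, num_classes, height, width); argmax over dim=1 yields class masks.
probabilities = trainer.predict(model=task, datamodule=datamodule)
masks = [probs.argmax(dim=1) for probs in probabilities]  # each (batch_size, height, width)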
