Source code for torchgeo.trainers.segmentation
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Segmentation tasks."""

from typing import Any, Dict, cast

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, JaccardIndex, MetricCollection

from ..datasets.utils import unbind_samples
from ..models import FCN

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"
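# A sketch of the symptom the links above describe: DataLoader is defined in
# the private module torch.utils.data.dataloader, so before the reassignment
#
#   >>> DataLoader.__module__
#   'torch.utils.data.dataloader'
#
# which Sphinx cannot resolve to the public torch.utils.data.DataLoader path.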


class SemanticSegmentationTask(LightningModule):
    """LightningModule for semantic segmentation of images."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        if self.hyperparams["segmentation_model"] == "unet":
            self.model = smp.Unet(
                encoder_name=self.hyperparams["encoder_name"],
                encoder_weights=self.hyperparams["encoder_weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["segmentation_model"] == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.hyperparams["encoder_name"],
                encoder_weights=self.hyperparams["encoder_weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["segmentation_model"] == "fcn":
            self.model = FCN(
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
                num_filters=self.hyperparams["num_filters"],
            )
        else:
            raise ValueError(
                f"Model type '{self.hyperparams['segmentation_model']}' is not valid."
            )

        if self.hyperparams["loss"] == "ce":
            # -1000 is a sentinel that matches no valid class index, so nothing
            # is ignored when self.ignore_zeros is None
            self.loss = nn.CrossEntropyLoss(
                ignore_index=-1000 if self.ignore_zeros is None else 0
            )
        elif self.hyperparams["loss"] == "jaccard":
            self.loss = smp.losses.JaccardLoss(
                mode="multiclass", classes=self.hyperparams["num_classes"]
            )
        elif self.hyperparams["loss"] == "focal":
            self.loss = smp.losses.FocalLoss(
                "multiclass", ignore_index=self.ignore_zeros, normalized=True
            )
        else:
            raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.")

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            segmentation_model: Name of the segmentation model type to use
            encoder_name: Name of the encoder model backbone to use
            encoder_weights: None or "imagenet" to use imagenet pretrained weights in
                the encoder model
            in_channels: Number of channels in input image
            num_classes: Number of semantic classes to predict
            loss: Name of the loss function
            ignore_zeros: Whether to ignore the "0" class value in the loss and metrics

        Raises:
            ValueError: if kwargs arguments are invalid
        """
        super().__init__()

        # Creates `self.hparams` from kwargs
        self.save_hyperparameters()  # type: ignore[operator]
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        # 0 means "ignore class index 0"; None means no index is ignored
        self.ignore_zeros = 0 if kwargs["ignore_zeros"] else None

        self.config_task()

        self.train_metrics = MetricCollection(
            [
                Accuracy(
                    num_classes=self.hyperparams["num_classes"],
                    ignore_index=self.ignore_zeros,
                ),
                JaccardIndex(
                    num_classes=self.hyperparams["num_classes"],
                    ignore_index=self.ignore_zeros,
                ),
            ],
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")
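    # A sketch of what the prefixes buy us (semantics assumed from
    # torchmetrics' MetricCollection): compute() returns one key per metric,
    # named after the metric class with the prefix prepended, and each clone
    # accumulates its own state independently:
    #
    #   metrics = MetricCollection([Accuracy(num_classes=2)], prefix="train_")
    #   preds = torch.zeros(4, dtype=torch.long)
    #   metrics(preds, preds)
    #   metrics.compute()  # {"train_Accuracy": tensor(1.)}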

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            args: tensor(s) of data to run through the model
            kwargs: keyword arguments passed through to the model

        Returns:
            output from the model
        """
        return self.model(*args, **kwargs)
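    # Shape sketch (hypothetical sizes; assumes a "unet" task with
    # in_channels=3 and num_classes=5):
    #
    #   x = torch.rand(8, 3, 256, 256)   # (batch, channels, height, width)
    #   y_hat = task(x)                  # (8, 5, 256, 256) per-pixel scores
    #   y_hat.argmax(dim=1)              # (8, 256, 256) hard class labels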

    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader

        Returns:
            training loss
        """
        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the train step logs every `log_every_n_steps` steps where
        # `log_every_n_steps` is a parameter to the `Trainer` object
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat_hard, y)

        return cast(Tensor, loss)
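    # The on_step=True flag means "train_loss" is emitted every
    # `log_every_n_steps` batches; the interval is set on the Trainer, e.g.
    # (50 is pytorch-lightning's documented default):
    #
    #   pl.Trainer(log_every_n_steps=50)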

    def training_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level training metrics.

        Args:
            outputs: list of items returned by training_step
        """
        self.log_dict(self.train_metrics.compute())
        self.train_metrics.reset()

    def validation_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute validation loss and log example predictions.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
        """
        batch = args[0]
        batch_idx = args[1]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat_hard, y)

        if batch_idx < 10:
            try:
                datamodule = self.trainer.datamodule  # type: ignore[attr-defined]
                batch["prediction"] = y_hat_hard
                for key in ["image", "mask", "prediction"]:
                    batch[key] = batch[key].cpu()
                sample = unbind_samples(batch)[0]
                fig = datamodule.plot(sample)
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
            except AttributeError:
                pass
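    # The try/except above assumes the trainer has a datamodule with a `plot`
    # method and a logger whose `experiment` exposes `add_figure` (e.g. a
    # TensorBoard SummaryWriter). A minimal sketch of a compatible plot method
    # (hypothetical; assumes matplotlib.pyplot imported as plt):
    #
    #   def plot(self, sample):
    #       fig, axes = plt.subplots(1, 3)
    #       axes[0].imshow(sample["image"].permute(1, 2, 0))  # CHW -> HWC
    #       axes[1].imshow(sample["mask"])
    #       axes[2].imshow(sample["prediction"])
    #       return fig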

    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def test_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute test loss.

        Args:
            batch: the output of your DataLoader
        """
        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self.forward(x)
        y_hat_hard = y_hat.argmax(dim=1)

        loss = self.loss(y_hat, y)

        # by default, the test and validation steps only log per *epoch*
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(y_hat_hard, y)

    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.hyperparams["learning_rate"]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    patience=self.hyperparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_loss",
            },
        }
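

# A minimal end-to-end sketch, assuming a LightningDataModule whose batches
# are dicts with "image" and "mask" keys and which implements `plot`
# (MyDataModule below is hypothetical):
#
#   import pytorch_lightning as pl
#
#   task = SemanticSegmentationTask(
#       segmentation_model="unet",
#       encoder_name="resnet18",
#       encoder_weights="imagenet",
#       in_channels=3,
#       num_classes=5,
#       loss="ce",
#       ignore_zeros=False,
#       learning_rate=1e-3,
#       learning_rate_schedule_patience=5,
#   )
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(task, datamodule=MyDataModule())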