# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Trainers for image classification."""

import os
from typing import Any

import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
from matplotlib.figure import Figure
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
from torchmetrics import MetricCollection
from torchmetrics.classification import (
MulticlassAccuracy,
MulticlassFBetaScore,
MulticlassJaccardIndex,
MultilabelAccuracy,
MultilabelFBetaScore,
)
from torchvision.models._api import WeightsEnum

from ..datasets import RGBBandsMissingError, unbind_samples
from ..models import get_weight
from . import utils
from .base import BaseTask


class ClassificationTask(BaseTask):
"""Image classification."""

    def __init__(
self,
model: str = "resnet50",
weights: WeightsEnum | str | bool | None = None,
in_channels: int = 3,
num_classes: int = 1000,
loss: str = "ce",
class_weights: Tensor | None = None,
lr: float = 1e-3,
patience: int = 10,
freeze_backbone: bool = False,
) -> None:
"""Initialize a new ClassificationTask instance.

        Args:
            model: Name of the `timm
                <https://huggingface.co/docs/timm/reference/models>`__ model to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False
                or None for random weights, or the path to a saved model state dict.
            in_channels: Number of input channels to model.
            num_classes: Number of prediction classes.
            loss: One of 'ce', 'bce', 'jaccard', or 'focal'.
            class_weights: Optional rescaling weight given to each
                class and used with 'ce' loss.
            lr: Learning rate for optimizer.
            patience: Patience for learning rate scheduler.
            freeze_backbone: Freeze the backbone network to linear probe
                the classifier head.

        .. versionchanged:: 0.4
           *classification_model* was renamed to *model*.

        .. versionadded:: 0.5
           The *class_weights* and *freeze_backbone* parameters.

        .. versionchanged:: 0.5
           *learning_rate* and *learning_rate_schedule_patience* were renamed to
           *lr* and *patience*.
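
        Example:
            A minimal training sketch (illustrative, not part of the upstream
            docs; ``datamodule`` stands in for any Lightning datamodule whose
            batches provide ``image`` and ``label`` keys):

            >>> from lightning.pytorch import Trainer
            >>> task = ClassificationTask(model="resnet18", num_classes=10)
            >>> trainer = Trainer(max_epochs=1)
            >>> trainer.fit(task, datamodule=datamodule)  # doctest: +SKIP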
"""
self.weights = weights
super().__init__(ignore="weights")

    def configure_models(self) -> None:
"""Initialize the model."""
weights = self.weights
# Create model
self.model = timm.create_model(
self.hparams["model"],
num_classes=self.hparams["num_classes"],
in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
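
    # Illustrative sanity check (not part of the class): with
    # ``freeze_backbone=True`` only the classifier head remains trainable.
    # For timm ResNets the head parameters are named "fc.*"; other
    # architectures use different names.
    #
    #     task = ClassificationTask(model="resnet18", freeze_backbone=True)
    #     trainable = [n for n, p in task.model.named_parameters() if p.requires_grad]
    #     assert all(n.startswith("fc") for n in trainable)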

    def configure_losses(self) -> None:
"""Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
"""
loss: str = self.hparams["loss"]
if loss == "ce":
self.criterion: nn.Module = nn.CrossEntropyLoss(
weight=self.hparams["class_weights"]
)
elif loss == "bce":
self.criterion = nn.BCEWithLogitsLoss()
elif loss == "jaccard":
self.criterion = JaccardLoss(mode="multiclass")
elif loss == "focal":
self.criterion = FocalLoss(mode="multiclass", normalized=True)
else:
raise ValueError(f"Loss type '{loss}' is not valid.")
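
    # Usage sketch (weight values are illustrative): per-class rescaling via
    # ``class_weights`` only affects the 'ce' criterion above.
    #
    #     weights = torch.tensor([1.0, 2.0, 0.5])
    #     task = ClassificationTask(num_classes=3, loss="ce", class_weights=weights)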

    def configure_metrics(self) -> None:
"""Initialize the performance metrics.

        * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
          true positives divided by the dataset size. Both overall accuracy (OA)
          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
          are reported. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'macro' averaging. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
          The harmonic mean of precision and recall. Uses 'micro' averaging.
          Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not
             reflect minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, and is useful for
             balanced performance assessment across imbalanced classes.
"""
metrics = MetricCollection(
{
"OverallAccuracy": MulticlassAccuracy(
num_classes=self.hparams["num_classes"], average="micro"
),
"AverageAccuracy": MulticlassAccuracy(
num_classes=self.hparams["num_classes"], average="macro"
),
"JaccardIndex": MulticlassJaccardIndex(
num_classes=self.hparams["num_classes"]
),
"F1Score": MulticlassFBetaScore(
num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
),
}
)
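        # Cloning with prefixes yields logged metric names such as
        # "train_OverallAccuracy", "val_JaccardIndex", and "test_F1Score".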
self.train_metrics = metrics.clone(prefix="train_")
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")

    def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss, batch_size=batch_size)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics, batch_size=batch_size)
return loss

    def validation_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log("val_loss", loss, batch_size=batch_size)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics, batch_size=batch_size)
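        # Plot example predictions for the first 10 batches. The hasattr
        # guards make this a no-op for datamodules without a ``plot`` method
        # and for loggers lacking TensorBoard's ``add_figure``.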
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
):
datamodule = self.trainer.datamodule
batch["prediction"] = y_hat.argmax(dim=-1)
for key in ["image", "label", "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig: Figure | None = None
try:
fig = datamodule.plot(sample)
except RGBBandsMissingError:
pass
if fig:
summary_writer = self.logger.experiment
summary_writer.add_figure(
f"image/{batch_idx}", fig, global_step=self.global_step
)
plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log("test_loss", loss, batch_size=batch_size)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics, batch_size=batch_size)

    def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
"""
x = batch["image"]
y_hat: Tensor = self(x).softmax(dim=-1)
return y_hat
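
    # Inference sketch (illustrative): Lightning's ``Trainer.predict`` returns
    # a list of per-batch probability tensors from ``predict_step``.
    #
    #     trainer = Trainer()
    #     probs = trainer.predict(task, datamodule=datamodule)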


class MultiLabelClassificationTask(ClassificationTask):
"""Multi-label image classification."""

    def configure_metrics(self) -> None:
"""Initialize the performance metrics.

        * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
          true positives divided by the dataset size. Both overall accuracy (OA)
          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
          are reported. Higher values are better.
        * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
          The harmonic mean of precision and recall. Uses 'micro' averaging.
          Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not
             reflect minority class accuracy.
           * 'Macro' averaging gives equal weight to each class, and is useful for
             balanced performance assessment across imbalanced classes.
"""
metrics = MetricCollection(
{
"OverallAccuracy": MultilabelAccuracy(
num_labels=self.hparams["num_classes"], average="micro"
),
"AverageAccuracy": MultilabelAccuracy(
num_labels=self.hparams["num_classes"], average="macro"
),
"F1Score": MultilabelFBetaScore(
num_labels=self.hparams["num_classes"], beta=1.0, average="micro"
),
}
)
self.train_metrics = metrics.clone(prefix="train_")
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")

    def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
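        # Multi-label targets are multi-hot vectors (e.g. classes {0, 2} of 4
        # encode as [1, 0, 1, 0]). BCEWithLogitsLoss expects float targets,
        # hence the cast below; sigmoid maps logits to per-label probabilities
        # for the metrics.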
y_hat_hard = torch.sigmoid(y_hat)
loss: Tensor = self.criterion(y_hat, y.to(torch.float))
self.log("train_loss", loss, batch_size=batch_size)
self.train_metrics(y_hat_hard, y)
        self.log_dict(self.train_metrics, batch_size=batch_size)
return loss

    def validation_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
loss = self.criterion(y_hat, y.to(torch.float))
self.log("val_loss", loss, batch_size=batch_size)
self.val_metrics(y_hat_hard, y)
self.log_dict(self.val_metrics, batch_size=batch_size)
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
):
datamodule = self.trainer.datamodule
batch["prediction"] = y_hat_hard
for key in ["image", "label", "prediction"]:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig: Figure | None = None
try:
fig = datamodule.plot(sample)
except RGBBandsMissingError:
pass
            if fig:
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
                plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
loss = self.criterion(y_hat, y.to(torch.float))
self.log("test_loss", loss, batch_size=batch_size)
self.test_metrics(y_hat_hard, y)
self.log_dict(self.test_metrics, batch_size=batch_size)

    def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
"""
x = batch["image"]
y_hat = torch.sigmoid(self(x))
return y_hat
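
    # Usage sketch (illustrative): pair this task with the 'bce' criterion,
    # which matches the sigmoid activation applied above.
    #
    #     task = MultiLabelClassificationTask(
    #         model="resnet18", num_classes=19, loss="bce"
    #     )
    #     trainer.fit(task, datamodule=datamodule)  # trainer/datamodule assumed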