
Source code for torchgeo.trainers.detection

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Detection tasks."""

from functools import partial
from typing import Any, Dict, List, cast

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models import resnet as R
from torchvision.models.detection import FCOS, FasterRCNN, RetinaNet
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.retinanet import RetinaNetHead
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign, feature_pyramid_network, misc

from ..datasets.utils import unbind_samples

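# Output channels of the final ResNet stage for each supported backbone;
# used below to size the extra FPN blocks when building RetinaNet.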
BACKBONE_LAT_DIM_MAP = {
    "resnet18": 512,
    "resnet34": 512,
    "resnet50": 2048,
    "resnet101": 2048,
    "resnet152": 2048,
    "resnext50_32x4d": 2048,
    "resnext101_32x8d": 2048,
    "wide_resnet50_2": 2048,
    "wide_resnet101_2": 2048,
}

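# Default torchvision pretrained (ImageNet) weights for each supported backbone.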
BACKBONE_WEIGHT_MAP = {
    "resnet18": R.ResNet18_Weights.DEFAULT,
    "resnet34": R.ResNet34_Weights.DEFAULT,
    "resnet50": R.ResNet50_Weights.DEFAULT,
    "resnet101": R.ResNet101_Weights.DEFAULT,
    "resnet152": R.ResNet152_Weights.DEFAULT,
    "resnext50_32x4d": R.ResNeXt50_32X4D_Weights.DEFAULT,
    "resnext101_32x8d": R.ResNeXt101_32X8D_Weights.DEFAULT,
    "wide_resnet50_2": R.Wide_ResNet50_2_Weights.DEFAULT,
    "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
}


class ObjectDetectionTask(pl.LightningModule):
    """LightningModule for object detection of images.

    Currently, supports Faster R-CNN, FCOS, and RetinaNet models from
    `torchvision <https://pytorch.org/vision/stable/models.html
    #object-detection-instance-segmentation-and-person-keypoint-detection>`_ with
    one of the following *backbone* arguments:

    .. code-block:: python

        ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
         'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2',
         'wide_resnet101_2']

    .. versionadded:: 0.4
    """

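    # A minimal usage sketch (not part of the module itself; the hyperparameter
    # names follow the keyword arguments documented in ``__init__``, and the
    # values are illustrative):
    #
    #     task = ObjectDetectionTask(
    #         model="faster-rcnn",
    #         backbone="resnet50",
    #         num_classes=2,
    #         learning_rate=1e-3,
    #         learning_rate_schedule_patience=5,
    #     )
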
    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        backbone_pretrained = self.hyperparams.get("pretrained", True)

        if self.hyperparams["backbone"] in BACKBONE_LAT_DIM_MAP:
            kwargs = {
                "backbone_name": self.hyperparams["backbone"],
                "trainable_layers": self.hyperparams.get("trainable_layers", 3),
            }
            if backbone_pretrained:
                kwargs["weights"] = BACKBONE_WEIGHT_MAP[self.hyperparams["backbone"]]
            else:
                kwargs["weights"] = None

            latent_dim = BACKBONE_LAT_DIM_MAP[self.hyperparams["backbone"]]
        else:
            raise ValueError(
                f"Backbone type '{self.hyperparams['backbone']}' is not valid."
            )

        num_classes = self.hyperparams["num_classes"]

        if self.hyperparams["model"] == "faster-rcnn":
            backbone = resnet_fpn_backbone(**kwargs)
            # One anchor-size tuple and one aspect-ratio tuple per FPN level.
            anchor_generator = AnchorGenerator(
                sizes=((32,), (64,), (128,), (256,), (512,)),
                aspect_ratios=((0.5, 1.0, 2.0),) * 5,
            )
            roi_pooler = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )
            self.model = FasterRCNN(
                backbone,
                num_classes,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
            )
        elif self.hyperparams["model"] == "fcos":
            kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(256, 256)
            kwargs["norm_layer"] = (
                misc.FrozenBatchNorm2d if kwargs["weights"] else torch.nn.BatchNorm2d
            )

            backbone = resnet_fpn_backbone(**kwargs)
            anchor_generator = AnchorGenerator(
                sizes=((8,), (16,), (32,), (64,), (128,), (256,)),
                aspect_ratios=((1.0,), (1.0,), (1.0,), (1.0,), (1.0,), (1.0,)),
            )
            self.model = FCOS(backbone, num_classes, anchor_generator=anchor_generator)
        elif self.hyperparams["model"] == "retinanet":
            kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(
                latent_dim, 256
            )
            backbone = resnet_fpn_backbone(**kwargs)

            anchor_sizes = (
                (16, 20, 25),
                (32, 40, 50),
                (64, 80, 101),
                (128, 161, 203),
                (256, 322, 406),
                (512, 645, 812),
            )
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
            head = RetinaNetHead(
                backbone.out_channels,
                anchor_generator.num_anchors_per_location()[0],
                num_classes,
                norm_layer=partial(torch.nn.GroupNorm, 32),
            )
            self.model = RetinaNet(
                backbone, num_classes, anchor_generator=anchor_generator, head=head
            )
        else:
            raise ValueError(f"Model type '{self.hyperparams['model']}' is not valid.")

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the LightningModule with a model and loss function.

        Keyword Args:
            model: Name of the detection model type to use
            backbone: Name of the model backbone to use
            in_channels: Number of channels in input image
            num_classes: Number of semantic classes to predict
            learning_rate: Learning rate for optimizer
            learning_rate_schedule_patience: Patience for learning rate scheduler

        Raises:
            ValueError: if kwargs arguments are invalid

        .. versionchanged:: 0.4
           The *detection_model* parameter was renamed to *model*.
        """
        super().__init__()

        # Creates `self.hparams` from kwargs
        self.save_hyperparameters()  # type: ignore[operator]
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        self.config_task()

        self.val_metrics = MeanAveragePrecision()
        self.test_metrics = MeanAveragePrecision()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return self.model(*args, **kwargs)

    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader

        Returns:
            training loss
        """
        batch = args[0]
        x = batch["image"]
        batch_size = x.shape[0]
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ]
        loss_dict = self(x, y)
        train_loss = sum(loss_dict.values())

        self.log_dict(loss_dict)

        return cast(Tensor, train_loss)

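    # The train/val/test steps above and below all expect batches shaped like
    # this sketch (a hypothetical batch of two 3-band 256x256 images; boxes are
    # ``(x1, y1, x2, y2)`` in pixel coordinates, as torchvision expects):
    #
    #     batch = {
    #         "image": torch.rand(2, 3, 256, 256),
    #         "boxes": [torch.tensor([[10.0, 10.0, 50.0, 50.0]])] * 2,
    #         "labels": [torch.tensor([1])] * 2,
    #     }
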
    def validation_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute validation loss and log example predictions.

        Args:
            batch: the output of your DataLoader
            batch_idx: the index of this batch
        """
        batch = args[0]
        batch_idx = args[1]
        x = batch["image"]
        batch_size = x.shape[0]
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ]
        y_hat = self(x)

        self.val_metrics.update(y_hat, y)

        if (
            batch_idx < 10
            and hasattr(self.trainer, "datamodule")
            and self.logger
            and hasattr(self.logger, "experiment")
        ):
            try:
                datamodule = self.trainer.datamodule
                batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat]
                batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat]
                batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat]
                batch["image"] = batch["image"].cpu()
                sample = unbind_samples(batch)[0]
                # Convert image to uint8 for plotting
                if torch.is_floating_point(sample["image"]):
                    sample["image"] *= 255
                    sample["image"] = sample["image"].to(torch.uint8)

                fig = datamodule.plot(sample)
                summary_writer = self.logger.experiment
                summary_writer.add_figure(
                    f"image/{batch_idx}", fig, global_step=self.global_step
                )
                plt.close()
            except ValueError:
                pass

    def validation_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level validation metrics.

        Args:
            outputs: list of items returned by validation_step
        """
        metrics = self.val_metrics.compute()
        renamed_metrics = {f"val_{i}": metrics[i] for i in metrics.keys()}
        self.log_dict(renamed_metrics)
        self.val_metrics.reset()

    def test_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute test MAP.

        Args:
            batch: the output of your DataLoader
        """
        batch = args[0]
        x = batch["image"]
        batch_size = x.shape[0]
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ]
        y_hat = self(x)

        self.test_metrics.update(y_hat, y)

    def test_epoch_end(self, outputs: Any) -> None:
        """Logs epoch level test metrics.

        Args:
            outputs: list of items returned by test_step
        """
        metrics = self.test_metrics.compute()
        renamed_metrics = {f"test_{i}": metrics[i] for i in metrics.keys()}
        self.log_dict(renamed_metrics)
        self.test_metrics.reset()

    def predict_step(self, *args: Any, **kwargs: Any) -> List[Dict[str, Tensor]]:
        """Compute and return the predictions.

        Args:
            batch: the output of your DataLoader

        Returns:
            list of predicted boxes, labels and scores
        """
        batch = args[0]
        x = batch["image"]
        y_hat: List[Dict[str, Tensor]] = self(x)
        return y_hat

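    # ``predict_step`` returns one dict per input image, e.g. (shapes
    # illustrative; ``N`` is the number of detections for that image):
    #
    #     [{"boxes": Tensor[N, 4], "labels": Tensor[N], "scores": Tensor[N]}]
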
    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation --
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.hyperparams["learning_rate"]
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(
                    optimizer,
                    mode="max",
                    patience=self.hyperparams["learning_rate_schedule_patience"],
                ),
                "monitor": "val_map",
            },
        }
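
# A hedged sketch of wiring the task into a Lightning ``Trainer``; the
# datamodule here is hypothetical, but anything yielding the batch structure
# sketched above will work. ``val_map``, logged in ``validation_epoch_end``,
# is what the LR scheduler above monitors, so it is also a natural
# checkpointing target:
#
#     from pytorch_lightning import Trainer
#     from pytorch_lightning.callbacks import ModelCheckpoint
#
#     checkpoint = ModelCheckpoint(monitor="val_map", mode="max")
#     trainer = Trainer(callbacks=[checkpoint], max_epochs=10)
#     trainer.fit(task, datamodule=my_datamodule)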
