Shortcuts

Source code for torchgeo.losses.qr

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Loss functions for learing on the prior."""

import torch
import torch.nn.functional as F
from torch.nn.modules import Module


[docs]class QRLoss(Module): """The QR (forward) loss between class probabilities and predictions. This loss is defined in `'Resolving label uncertainty with implicit generative models' <https://openreview.net/forum?id=AEa_UepnMDX>`_. .. versionadded:: 0.2 """
[docs] def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Computes the QR (forwards) loss on prior. Args: probs: probabilities of predictions, expected shape B x C x H x W. target: prior probabilities, expected shape B x C x H x W. Returns: qr loss """ q = probs q_bar = q.mean(dim=(0, 2, 3)) qbar_log_S = (q_bar * torch.log(q_bar)).sum() q_log_p = torch.einsum('bcxy,bcxy->bxy', q, torch.log(target)).mean() loss = qbar_log_S - q_log_p return loss
[docs]class RQLoss(Module): """The RQ (backwards) loss between class probabilities and predictions. This loss is defined in `'Resolving label uncertainty with implicit generative models' <https://openreview.net/forum?id=AEa_UepnMDX>`_. .. versionadded:: 0.2 """
[docs] def forward(self, probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Computes the RQ (backwards) loss on prior. Args: probs: probabilities of predictions, expected shape B x C x H x W target: prior probabilities, expected shape B x C x H x W Returns: qr loss """ q = probs # manually normalize due to https://github.com/pytorch/pytorch/issues/70100 z = q / q.norm(p=1, dim=(0, 2, 3), keepdim=True).clamp_min(1e-12).expand_as(q) r = F.normalize(z * target, p=1, dim=1) loss = torch.einsum('bcxy,bcxy->bxy', r, torch.log(r) - torch.log(q)).mean() return loss

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources