# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""BYOL tasks."""

import random
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Tuple, cast

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
from kornia import augmentation as K
from kornia import filters
from kornia.geometry import transform as KorniaTransform
from packaging.version import parse
from torch import Tensor, optim
from torch.autograd import Variable
from torch.nn.modules import BatchNorm1d, Conv2d, Linear, Module, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Report ``torch.nn`` as the defining module of ``Module`` instead of its private
# implementation module -- presumably so generated API docs show the public path
# for base classes; TODO confirm intent.
setattr(Module, "__module__", "torch.nn")

def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
    """Computes the normalized mean squared error between x and y.

    Both inputs are L2-normalized along their last dimension. For unit vectors the
    squared Euclidean distance equals ``2 - 2 * <x, y>``, which is what is computed
    here; the result is then averaged over all remaining dimensions.

    Args:
        x: tensor x
        y: tensor y, broadcastable against ``x``

    Returns:
        the normalized MSE between x and y, a scalar in ``[0, 4]``
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    # Squared distance between unit vectors = 2 - 2 * cosine similarity
    return torch.mean(2 - 2 * (x * y).sum(dim=-1))

# TODO: Move this to transforms
class RandomApply(Module):
    """Applies augmentation function (augm) with probability p."""

    def __init__(self, augm: Callable[[Tensor], Tensor], p: float) -> None:
        """Initialize RandomApply.

        Args:
            augm: augmentation function to apply
            p: probability with which the augmentation function is applied
        """
        # Must run before any attribute assignment: when ``augm`` is itself a
        # Module, nn.Module.__setattr__ needs the submodule registry to exist.
        super().__init__()
        self.augm = augm
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """Applies an augmentation to the input with some probability.

        A single random draw is made per call, so either the whole batch is
        augmented or none of it is.

        Args:
            x: a batch of imagery

        Returns:
            augmented version of ``x`` with probability ``self.p`` else an
            un-augmented version
        """
        return x if random.random() > self.p else self.augm(x)

# TODO: This isn't _really_ applying the augmentations from SimCLR as we have
# multispectral imagery and thus can't naively apply color jittering or grayscale
# conversions. We should think more about what makes sense here.
class SimCLRAugmentation(Module):
    """A module for applying SimCLR augmentations.

    SimCLR was one of the first papers to show the effectiveness of random data
    augmentation in self-supervised-learning setups. See
    https://arxiv.org/abs/2002.05709 for more details.
    """

    def __init__(self, image_size: Tuple[int, int] = (256, 256)) -> None:
        """Initialize a module for applying SimCLR augmentations.

        Args:
            image_size: Tuple of integers defining the image size
        """
        super().__init__()
        self.size = image_size

        # The random flip and random resized crop are what make two calls on the
        # same batch yield different "views"; without them the only stochastic
        # step is a rarely-applied blur and both BYOL views would be nearly equal.
        self.augmentation = Sequential(
            KorniaTransform.Resize(size=image_size, align_corners=False),
            # Not suitable for multispectral adapt
            # RandomApply(K.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            # K.RandomGrayscale(p=0.2),
            K.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            K.RandomResizedCrop(size=image_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Applies SimCLR augmentations to the input tensor.

        Args:
            x: a batch of imagery

        Returns:
            an augmented batch of imagery
        """
        return cast(Tensor, self.augmentation(x))

class MLP(Module):
    """MLP used in the BYOL projection head."""

    def __init__(
        self, dim: int, projection_size: int = 256, hidden_size: int = 4096
    ) -> None:
        """Initializes the MLP projection head.

        The architecture is ``Linear -> BatchNorm1d -> ReLU -> Linear``. The
        nonlinearity matters: two stacked linear layers alone collapse into a
        single linear map. Note that in training mode BatchNorm1d requires batches
        of more than one sample.

        Args:
            dim: size of layer to project
            projection_size: size of the output layer
            hidden_size: size of the hidden layer
        """
        super().__init__()
        self.mlp = Sequential(
            Linear(dim, hidden_size),
            BatchNorm1d(hidden_size),
            ReLU(inplace=True),
            Linear(hidden_size, projection_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the MLP model.

        Args:
            x: batch of imagery

        Returns:
            embedded version of the input
        """
        return cast(Tensor, self.mlp(x))

class EncoderWrapper(Module):
    """Encoder wrapper for joining a model and a projection head.

    When we call .forward() on this module the following steps happen:

    * The input is passed through the base model
    * When the encoding layer is reached a hook is called
    * The output of the encoding layer is passed through the projection head
    * The forward call returns the output of the projection head
    """

    def __init__(
        self,
        model: Module,
        projection_size: int = 256,
        hidden_size: int = 4096,
        layer: int = -2,
    ) -> None:
        """Initializes EncoderWrapper.

        Args:
            model: model to encode
            projection_size: size of the output layer of the projector MLP
            hidden_size: size of hidden layer of the projector MLP
            layer: index into ``model.children()`` of the layer to project
        """
        super().__init__()

        self.model = model
        self.projection_size = projection_size
        self.hidden_size = hidden_size
        self.layer = layer

        # The projector is built lazily on the first forward pass, once the
        # flattened size of the hooked layer's output is known.
        self._projector: Optional[Module] = None
        self._projector_dim: Optional[int] = None
        # Filled in by ``_hook`` on every forward pass of ``model``
        self._encoded = torch.empty(0)
        self._register_hook()

    @property
    def projector(self) -> Module:
        """Wrapper module for the projector head.

        The MLP is created on first access; ``_projector_dim`` must already have
        been measured by ``_hook``.
        """
        assert self._projector_dim is not None
        if self._projector is None:
            self._projector = MLP(
                self._projector_dim, self.projection_size, self.hidden_size
            )
        return self._projector

    def _hook(self, module: Any, inputs: Any, output: Tensor) -> None:
        """Hook to record the activations at the projection layer.

        See the following docs page for more details on hooks:
        https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook

        Args:
            module: the calling module
            inputs: input to the module this hook was registered to
            output: output from the module this hook was registered to
        """
        output = output.flatten(start_dim=1)
        if self._projector_dim is None:
            # If we haven't already, measure the output size
            self._projector_dim = output.shape[-1]

        # Project the output to get encodings, the projector model is created the first
        # time this is called
        self._encoded = self.projector(output)

    def _register_hook(self) -> None:
        """Register a hook for layer that we will extract features from."""
        layer = list(self.model.children())[self.layer]
        layer.register_forward_hook(self._hook)

    def forward(self, x: Tensor) -> Tensor:
        """Pass through the model, and collect the representation from our forward hook.

        Args:
            x: tensor of data to run through the model

        Returns:
            output of the projection head, as recorded by the forward hook
        """
        # The model's own output is discarded; ``_hook`` stores the projection.
        _ = self.model(x)
        return self._encoded

class BYOL(Module):
    """BYOL implementation.

    BYOL contains two identical encoder networks. The first is trained as usual, and its
    weights are updated with each training batch. The second, "target" network, is
    updated using a running average of the first encoder's weights.

    See https://arxiv.org/abs/2006.07733 for more details (and please cite it if you
    use it in your own work).
    """

    def __init__(
        self,
        model: Module,
        image_size: Tuple[int, int] = (256, 256),
        hidden_layer: int = -2,
        in_channels: int = 4,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Optional[Module] = None,
        beta: float = 0.99,
        **kwargs: Any,
    ) -> None:
        """Sets up a model for pre-training with BYOL using projection heads.

        Args:
            model: the model to pretrain using BYOL
            image_size: the size of the training images
            hidden_layer: index into ``model.children()`` of the hidden layer to
                attach the projection head to
            in_channels: number of input channels to the model
            projection_size: size of first layer of the projection MLP
            hidden_size: size of the hidden layer of the projection MLP
            augment_fn: an instance of a module that performs data augmentation
            beta: the speed at which the target encoder is updated using the main
                encoder; the target keeps a fraction ``beta`` of its own weights
                on every update, so larger values update it more slowly
            **kwargs: accepted and ignored
        """
        super().__init__()

        self.augment: Module
        if augment_fn is None:
            self.augment = SimCLRAugmentation(image_size)
        else:
            self.augment = augment_fn

        self.beta = beta
        self.in_channels = in_channels
        self.encoder = EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        self.predictor = MLP(projection_size, projection_size, hidden_size)

        # Perform a single forward pass to initialize the wrapper correctly (its
        # projector is only built once the hooked layer's output size is known).
        # No autograd graph is needed, and a graph-attached cached output could
        # not be deep-copied below.
        with torch.no_grad():
            self.encoder(torch.zeros(2, self.in_channels, *image_size))

        # The target starts as an exact copy of the online encoder (backbone and
        # projector) but owns separate weights; sharing ``model`` would make the
        # moving average in ``update_target`` a no-op. Deep-copying the wrapper
        # also rebinds the copied model's forward hook to the copied wrapper.
        self.target = deepcopy(self.encoder)
        # The target is only ever changed by ``update_target``, never by gradients
        self.target.requires_grad_(False)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the encoder model through the MLP and prediction head.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return cast(Tensor, self.predictor(self.encoder(x)))

    def update_target(self) -> None:
        """Method to update the "target" model weights.

        Each target parameter becomes an exponential moving average of itself and
        the matching online-encoder parameter:
        ``target = beta * target + (1 - beta) * online``.
        """
        with torch.no_grad():
            for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
                pt.mul_(self.beta).add_(p, alpha=1 - self.beta)

class BYOLTask(pl.LightningModule):
    """Class for pre-training any PyTorch model using BYOL."""

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor.

        Builds a torchvision ResNet encoder, replaces its first convolution with one
        accepting ``in_channels`` input channels, and wraps it in ``BYOL``.
        """
        in_channels = self.hyperparams["in_channels"]
        pretrained = self.hyperparams["imagenet_pretraining"]
        encoder_name = self.hyperparams["encoder_name"]

        # torchvision 0.13 replaced the ``pretrained`` flag with ``weights`` enums
        kwargs: Dict[str, Any]
        if parse(torchvision.__version__) >= parse("0.13"):
            if pretrained:
                # e.g. "resnet50" -> torchvision.models.ResNet50_Weights
                weights_enum = getattr(
                    torchvision.models, f"ResNet{encoder_name[6:]}_Weights"
                )
                kwargs = {"weights": weights_enum.DEFAULT}
            else:
                kwargs = {"weights": None}
        else:
            kwargs = {"pretrained": pretrained}

        encoder = getattr(torchvision.models, encoder_name)(**kwargs)

        layer = encoder.conv1
        # Creating new Conv2d layer with the same geometry but ``in_channels`` inputs
        new_layer = Conv2d(
            in_channels=in_channels,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            bias=layer.bias is not None,
        )

        # initialize the weights from new channel with the red channel weights
        copy_weights = 0
        # Channels present in both layers are copied one-to-one; min() also
        # supports inputs with fewer channels than the original layer.
        num_shared = min(in_channels, layer.in_channels)
        with torch.no_grad():
            # Copying the weights from the old to the new layer
            new_layer.weight[:, :num_shared] = layer.weight[:, :num_shared]
            if in_channels > num_shared:
                # Copying the weights of the old layer to the extra channels
                new_layer.weight[:, num_shared:] = layer.weight[
                    :, copy_weights : copy_weights + 1
                ]
            if layer.bias is not None:
                new_layer.bias.copy_(layer.bias)

        encoder.conv1 = new_layer
        self.model = BYOL(encoder, in_channels=in_channels, image_size=(256, 256))

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a LightningModule for pre-training a model with BYOL.

        Keyword Args:
            in_channels: number of channels on the input imagery
            encoder_name: name of a torchvision ResNet constructor, e.g.
                "resnet18" or "resnet50"
            imagenet_pretraining: bool indicating whether to use imagenet pretrained
                weights
            optimizer: optional name of a ``torch.optim`` class (default "Adam")
            lr: optional learning rate (default 1e-4)
            weight_decay: optional weight decay (default 1e-6)
            learning_rate_schedule_patience: optional ReduceLROnPlateau patience
                (default 10)

        Raises:
            KeyError: if a required keyword argument is missing
        """
        super().__init__()

        # Creates `self.hparams` from kwargs
        self.save_hyperparameters()  # type: ignore[operator]
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        self.config_task()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            x: tensor of data to run through the model

        Returns:
            output from the model
        """
        return self.model(*args, **kwargs)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            a "lr dict" according to the pytorch lightning documentation
        """
        optimizer_class = getattr(optim, self.hyperparams.get("optimizer", "Adam"))
        lr = self.hyperparams.get("lr", 1e-4)
        weight_decay = self.hyperparams.get("weight_decay", 1e-6)
        optimizer = optimizer_class(self.parameters(), lr=lr, weight_decay=weight_decay)

        # 10 matches ReduceLROnPlateau's own default patience
        patience = self.hyperparams.get("learning_rate_schedule_patience", 10)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, patience=patience),
                "monitor": "val_loss",
            },
        }

    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the training loss.

        Args:
            batch: the output of your DataLoader

        Returns:
            training loss
        """
        batch = args[0]
        x = batch["image"]
        with torch.no_grad():
            x1, x2 = self.model.augment(x), self.model.augment(x)

        pred1, pred2 = self.forward(x1), self.forward(x2)
        # Targets come from the moving-average network and must not receive grads
        with torch.no_grad():
            targ1, targ2 = self.model.target(x1), self.model.target(x2)
        # Symmetric loss: each view's prediction is matched to the other's target
        loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.model.update_target()

        return loss

    def validation_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute validation loss.

        Args:
            batch: the output of your DataLoader
        """
        batch = args[0]
        x = batch["image"]
        x1, x2 = self.model.augment(x), self.model.augment(x)
        pred1, pred2 = self.forward(x1), self.forward(x2)
        targ1, targ2 = self.model.target(x1), self.model.target(x2)
        loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

        self.log("val_loss", loss, on_step=False, on_epoch=True)

    def test_step(self, *args: Any, **kwargs: Any) -> Any:
        """No-op, does nothing."""

