
Source code for torchgeo.models.dofa

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Dynamic One-For-All (DOFA) models."""

from functools import partial
from typing import Any

import kornia.augmentation as K
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from timm.models.vision_transformer import Block
from torch import Tensor
from torchvision.models._api import Weights, WeightsEnum

def position_embedding(embed_dim: int, pos: Tensor) -> Tensor:
    """Compute the 1D sine/cosine position embedding.

    Args:
        embed_dim: Output dimension D for each position. Must be even.
        pos: A list of positions to be encoded, of size (M,). Any shape is
            accepted; it is flattened to (M,) first.

    Returns:
        Position embeddings of size (M, D): the first D/2 columns hold the sines
        and the last D/2 columns hold the cosines.

    Raises:
        AssertionError: If *embed_dim* is not even.
    """
    # Explicit raise (instead of ``assert``) so the check is not stripped under
    # ``python -O``; AssertionError is kept so existing callers are unaffected.
    if embed_dim % 2 != 0:
        raise AssertionError(f'embed_dim must be even, got {embed_dim}')

    # Frequencies omega_i = 1 / 10000^(2i / D) for i in [0, D/2).
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb

class TransformerWeightGenerator(nn.Module):
    """Dynamic weight generator for DOFA.

    A transformer encoder that maps a set of per-band embeddings to one row of
    generated weights per band, plus a single generated bias vector.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        embed_dim: int,
        num_heads: int = 4,
        num_layers: int = 1,
    ) -> None:
        """Initialize a new TransformerWeightGenerator instance.

        Args:
            input_dim: Input dimensions.
            output_dim: Output dimensions.
            embed_dim: Embedding dimensions.
            num_heads: Number of heads.
            num_layers: Number of layers.
        """
        super().__init__()
        # NOTE(review): these encoder-layer arguments were missing from the copy
        # under review and have been restored; d_model/nhead follow from the
        # signature, the remaining settings should be confirmed against upstream.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            activation='gelu',
            norm_first=False,
            batch_first=False,
            dropout=0.0,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )

        # Linear layer to map transformer output to desired weight shape
        self.fc_weight = nn.Linear(input_dim, output_dim)
        self.fc_bias = nn.Linear(input_dim, embed_dim)
        # Number of learned weight tokens prepended to the band tokens.
        self.wt_num = 128
        self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim]))
        self.bias_token = nn.Parameter(torch.empty([1, input_dim]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is
        # too big (2.)
        torch.nn.init.normal_(self.weight_tokens, std=0.02)
        torch.nn.init.normal_(self.bias_token, std=0.02)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Forward pass of the model.

        Args:
            x: Per-band embeddings of size (num_bands, input_dim). The input must
                be 2-D because it is concatenated along dim 0 with the learned
                2-D weight and bias tokens.

        Returns:
            Weight of size (num_bands, output_dim) and bias of size (embed_dim,).
        """
        pos_wave = x
        # Token sequence: [weight tokens (wt_num) | band tokens | bias token (1)]
        tokens = torch.cat([self.weight_tokens, pos_wave, self.bias_token], dim=0)
        transformer_output = self.transformer_encoder(tokens)
        # Outputs at the band positions, with a residual to the input embeddings
        weights = self.fc_weight(transformer_output[self.wt_num : -1] + pos_wave)
        # Using the last output (the bias token) to generate bias
        bias = self.fc_bias(transformer_output[-1])
        return weights, bias

class FCResLayer(nn.Module):
    """Fully-connected residual layer.

    Computes ``x + relu(w2(relu(w1(x))))`` with two square linear layers.
    """

    def __init__(self, linear_size: int = 128) -> None:
        """Initialize a new FCResLayer instance.

        Args:
            linear_size: Size of linear layer (input and output width).
        """
        super().__init__()
        self.l_size = linear_size
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the model.

        Args:
            x: Input mini-batch whose last dimension is ``linear_size``.

        Returns:
            Output of the model, same shape as *x*.
        """
        # The in-place ReLUs only ever modify the fresh outputs of w1/w2, so the
        # input x (used again in the residual sum) is never overwritten.
        y = self.w1(x)
        y = self.nonlin1(y)
        y = self.w2(y)
        y = self.nonlin2(y)
        out: Tensor = x + y
        return out

class DOFAEmbedding(nn.Module):
    """Dynamic One-For-All (DOFA) embedding.

    Builds the patch-embedding convolution on the fly from the wavelengths of the
    input bands, so the same module accepts a varying number of bands.
    """

    def __init__(
        self, dynamic_embed_dim: int, kernel_size: int = 3, embed_dim: int = 1024
    ) -> None:
        """Initialize a new DOFAEmbedding instance.

        Args:
            dynamic_embed_dim: Dimensions of dynamic weight generator.
            kernel_size: Kernel size (and stride) of the generated patch-embedding
                convolution. This is a regular convolution (groups=1), not a
                depth-wise one.
            embed_dim: Embedding dimensions.
        """
        super().__init__()
        self.dynamic_embed_dim = dynamic_embed_dim
        self.kernel_size = kernel_size
        self.embed_dim = embed_dim
        # Number of generated weights per input band: k * k * embed_dim
        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
        self.patch_size = (kernel_size, kernel_size)
        self.num_patches = -1

        self.weight_generator = TransformerWeightGenerator(
            dynamic_embed_dim, self._num_kernel, embed_dim
        )
        # Multiplier applied to both generated weights and bias in forward()
        self.scaler = 0.01

        self.fclayer = FCResLayer(dynamic_embed_dim)

        # NOTE(review): this call and the _init_weight*/ bodies below were missing
        # from the copy under review and have been restored; confirm upstream.
        self._init_weights()

    def _init_weight(self, m: object) -> None:
        """Initialize weights of a single layer.

        Linear layers get Xavier-uniform weights and a constant 0.01 bias; other
        module types are left untouched.

        Args:
            m: A single layer.
        """
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def _init_weights(self) -> None:
        """Initialize weights of all layers."""
        self.weight_generator.apply(self._init_weight)
        self.fclayer.apply(self._init_weight)

    def forward(self, x: Tensor, wavelengths: Tensor) -> tuple[Tensor, Tensor]:
        """Forward pass of the model.

        Args:
            x: Input mini-batch of size (B, C, H, W), where C must equal
                ``len(wavelengths)`` (the generated kernel has one input channel
                per wavelength).
            wavelengths: Wavelengths of each spectral band (μm), of size (C,).

        Returns:
            Patch tokens of size (B, num_patches, embed_dim) and the per-band
            wavelength embeddings of size (C, dynamic_embed_dim).
        """
        inplanes = wavelengths.size(0)
        # Wavelengths are scaled by 1000 (μm -> nm) before the sin/cos encoding
        waves = position_embedding(self.dynamic_embed_dim, wavelengths * 1000)
        waves = self.fclayer(waves)  # (C, dynamic_embed_dim)
        weight, bias = self.weight_generator(waves)  # (C, k*k*D), (D,)

        # NOTE: ``view`` reinterprets the (C, k*k*D) buffer as (D, C, k, k); it is
        # not a transpose. Pretrained weights depend on this exact layout.
        dynamic_weight = weight.view(
            self.embed_dim, inplanes, self.kernel_size, self.kernel_size
        )  # (D, C, k, k)
        if bias is not None:
            bias = bias.view([self.embed_dim]) * self.scaler

        weights = dynamic_weight * self.scaler

        # stride == kernel size, so the convolution windows do not overlap
        dynamic_out = F.conv2d(
            x, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1
        )

        # (B, D, H', W') -> (B, H' * W', D)
        x = dynamic_out
        x = x.flatten(2).transpose(1, 2)

        return x, waves

class DOFA(nn.Module):
    """Dynamic One-For-All (DOFA) model.

    Reference implementation:

    * TODO(review): link missing from the copy under review.

    If you use this model in your research, please cite the following paper:

    * TODO(review): link missing from the copy under review.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        drop_rate: float = 0.0,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        dynamic_embed_dim: int = 128,
        num_classes: int = 45,
        global_pool: bool = True,
        mlp_ratio: float = 4.0,
        norm_layer: type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        """Initialize a new DOFA instance.

        Args:
            img_size: Input image size.
            patch_size: Patch size; also used as the kernel size and stride of the
                dynamic patch embedding.
            drop_rate: Head dropout rate.
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            dynamic_embed_dim: Dimensions of dynamic weight generator.
            num_classes: Number of classes for classification head.
            global_pool: Whether or not to perform global pooling.
            mlp_ratio: Ratio of MLP hidden dim to embedding dim.
            norm_layer: Normalization layer.
        """
        super().__init__()

        self.dynamic_embed_dim = dynamic_embed_dim
        self.global_pool = global_pool
        # With global pooling the pooled vector is normalized (fc_norm); otherwise
        # the whole token sequence is normalized (norm) and the cls token is used.
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)
        else:
            self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        # Pass the constructor arguments through (they were hard-coded to 128 and
        # 16) so the token count follows patch_size and matches num_patches /
        # pos_embed, e.g. for 224 px inputs with patch_size=14.
        self.patch_embed = DOFAEmbedding(
            dynamic_embed_dim=dynamic_embed_dim,
            kernel_size=patch_size,
            embed_dim=embed_dim,
        )
        self.num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # --------------------------------------------------------------------------

        # Frozen (requires_grad=False) and zero-initialized; this class never
        # fills in sin-cos values, so non-zero entries only come from a loaded
        # state dict.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )

        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x: Tensor, wavelengths: list[float]) -> Tensor:
        """Forward pass of the feature embedding layer.

        Args:
            x: Input mini-batch, one channel per entry in *wavelengths*.
            wavelengths: Wavelengths of each spectral band (μm).

        Returns:
            Output mini-batch of size (B, embed_dim).
        """
        # embed patches
        wavelist = torch.tensor(wavelengths, device=x.device).float()
        self.waves = wavelist
        x, _ = self.patch_embed(x, self.waves)

        x = x + self.pos_embed[:, 1:, :]
        # prepend cls token (with its own position embedding)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome: Tensor = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
        """Forward pass of the attention head.

        Args:
            x: Input mini-batch.
            pre_logits: Whether or not to return the layer before logits are
                computed.

        Returns:
            Output mini-batch.
        """
        x = self.head_drop(x)
        x = x if pre_logits else self.head(x)
        return x

    def forward(self, x: Tensor, wavelengths: list[float]) -> Tensor:
        """Forward pass of the model.

        Args:
            x: Input mini-batch.
            wavelengths: Wavelengths of each spectral band (μm).

        Returns:
            Output mini-batch.
        """
        x = self.forward_features(x, wavelengths)
        x = self.forward_head(x)
        return x
# Normalization is sensor-dependent and is therefore left out
_dofa_transforms = K.AugmentationSequential(K.CenterCrop((224, 224)), data_keys=None)

# Make copy.deepcopy() of a Weights instance return the instance itself.
# Can be removed once torchvision>=0.15 is required
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
class DOFABase16_Weights(WeightsEnum):  # type: ignore[misc]
    """Dynamic One-For-All (DOFA) base patch size 16 weights.

    .. versionadded:: 0.6
    """

    # NOTE(review): 'url', 'publication' and 'repo' are empty strings in the copy
    # under review; an empty URL cannot be downloaded, so restore them before use.
    DOFA_MAE = Weights(
        url='',  # noqa: E501
        transforms=_dofa_transforms,
        meta={
            'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
            'model': 'dofa_base_patch16_224',
            'publication': '',
            'repo': '',
            'ssl_method': 'mae',
        },
    )
class DOFALarge16_Weights(WeightsEnum):  # type: ignore[misc]
    """Dynamic One-For-All (DOFA) large patch size 16 weights.

    .. versionadded:: 0.6
    """

    # NOTE(review): 'url', 'publication' and 'repo' are empty strings in the copy
    # under review; an empty URL cannot be downloaded, so restore them before use.
    DOFA_MAE = Weights(
        url='',  # noqa: E501
        transforms=_dofa_transforms,
        meta={
            'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',
            'model': 'dofa_large_patch16_224',
            'publication': '',
            'repo': '',
            'ssl_method': 'mae',
        },
    )
def dofa_small_patch16_224(**kwargs: Any) -> DOFA:
    """Dynamic One-For-All (DOFA) small patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * TODO(review): link missing from the copy under review.

    .. versionadded:: 0.6

    Args:
        **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

    Returns:
        A DOFA small 16 model.
    """
    model = DOFA(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    return model
def dofa_base_patch16_224(
    weights: DOFABase16_Weights | None = None, **kwargs: Any
) -> DOFA:
    """Dynamic One-For-All (DOFA) base patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * TODO(review): link missing from the copy under review.

    .. versionadded:: 0.6

    Args:
        weights: Pre-trained model weights to use.
        **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

    Returns:
        A DOFA base 16 model.
    """
    model = DOFA(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)

    if weights is not None:
        missing_keys, unexpected_keys = model.load_state_dict(
            weights.get_state_dict(progress=True), strict=False
        )
        # Only fc_norm and the classification head may be absent from the
        # checkpoint; they keep their fresh initialization.
        allowed_missing = {
            'fc_norm.weight',
            'fc_norm.bias',
            'head.weight',
            'head.bias',
        }
        assert set(missing_keys) <= allowed_missing, (
            f'unexpected missing keys: {sorted(set(missing_keys) - allowed_missing)}'
        )
        assert not unexpected_keys, f'unexpected keys: {unexpected_keys}'

    return model
def dofa_large_patch16_224(
    weights: DOFALarge16_Weights | None = None, **kwargs: Any
) -> DOFA:
    """Dynamic One-For-All (DOFA) large patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * TODO(review): link missing from the copy under review.

    .. versionadded:: 0.6

    Args:
        weights: Pre-trained model weights to use.
        **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

    Returns:
        A DOFA large 16 model.
    """
    model = DOFA(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)

    if weights is not None:
        missing_keys, unexpected_keys = model.load_state_dict(
            weights.get_state_dict(progress=True), strict=False
        )
        # Only fc_norm and the classification head may be absent from the
        # checkpoint; they keep their fresh initialization.
        allowed_missing = {
            'fc_norm.weight',
            'fc_norm.bias',
            'head.weight',
            'head.bias',
        }
        assert set(missing_keys) <= allowed_missing, (
            f'unexpected missing keys: {sorted(set(missing_keys) - allowed_missing)}'
        )
        assert not unexpected_keys, f'unexpected keys: {unexpected_keys}'

    return model
def dofa_huge_patch16_224(**kwargs: Any) -> DOFA:
    """Dynamic One-For-All (DOFA) huge patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * TODO(review): link missing from the copy under review.

    .. versionadded:: 0.6

    Args:
        **kwargs: Additional keyword arguments to pass to :class:`DOFA`.

    Returns:
        A DOFA huge 16 model.
    """
    # NOTE(review): despite "patch16" in the name and summary, this model is
    # built with patch_size=14; confirm which patch size is intended.
    model = DOFA(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
    return model

© Copyright 2021, Microsoft Corporation. Revision 94bd5c76.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources