Shortcuts

Source code for torchgeo.models.api

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""APIs for querying and loading pre-trained model weights.

See the following references for design details:

* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
"""  # noqa: E501

from collections.abc import Callable
from typing import Any

import torch.nn as nn
from torchvision.models._api import WeightsEnum

from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .swin import Swin_V2_B_Weights, swin_v2_b
from .vit import ViTSmall16_Weights, vit_small_patch16_224

# Registry of model builders, keyed by the public name accepted by
# ``get_model``. Insertion order is the order reported by ``list_models``.
_model = dict(
    resnet18=resnet18,
    resnet50=resnet50,
    vit_small_patch16_224=vit_small_patch16_224,
    swin_v2_b=swin_v2_b,
)

# Registry of weights enum classes, keyed by model builder function.
# Each entry is also reachable by the builder's registered name (a key of
# ``_model``). Those string aliases are derived from ``_model`` below rather
# than listed by hand, so the two registries cannot drift apart when a model
# is added. A builder registered in ``_model`` without an entry here fails
# loudly with a KeyError at import time.
_model_weights: dict[Callable[..., nn.Module] | str, WeightsEnum] = {
    resnet18: ResNet18_Weights,
    resnet50: ResNet50_Weights,
    vit_small_patch16_224: ViTSmall16_Weights,
    swin_v2_b: Swin_V2_B_Weights,
}
_model_weights.update(
    {_name: _model_weights[_builder] for _name, _builder in _model.items()}
)


def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
    """Build a model from the name it is registered under.

    .. versionadded:: 0.4

    Args:
        name: Registered name of the model.
        *args: Extra positional arguments forwarded to the model builder.
        **kwargs: Extra keyword arguments forwarded to the model builder.

    Returns:
        A freshly instantiated model.

    Raises:
        KeyError: If *name* is not a registered model name.
    """
    # Resolve the builder first; every extra argument is forwarded untouched.
    builder = _model[name]
    instance: nn.Module = builder(*args, **kwargs)
    return instance


def get_model_weights(name: Callable[..., nn.Module] | str) -> WeightsEnum:
    """Look up the weights enum class that belongs to a model.

    .. versionadded:: 0.4

    Args:
        name: Either the model builder function itself or the name it is
            registered under.

    Returns:
        The weights enum class registered for that model.

    Raises:
        KeyError: If *name* is neither a registered builder nor a registered
            model name.
    """
    weights_enum = _model_weights[name]
    return weights_enum


def get_weight(name: str) -> WeightsEnum:
    """Get the weights enum value by its full name.

    .. versionadded:: 0.4

    Args:
        name: Name of the weight enum entry, written as
            ``"<EnumClassName>.<MEMBER>"`` (e.g. ``"ResNet50_Weights.<MEMBER>"``),
            where the enum class is one of the weights enums registered in
            ``_model_weights``.

    Returns:
        The requested weight enum.

    Raises:
        ValueError: If *name* does not name a member of a registered weights
            enum.
    """
    # Resolve the name by explicit lookup instead of eval(): eval would execute
    # any Python expression passed in *name*, which is unsafe for user input.
    enum_name, sep, member_name = name.strip().rpartition(".")
    if sep:
        # Each enum class appears more than once in the registry (builder key
        # and string key); we stop at the first class whose name matches.
        for weights_enum in _model_weights.values():
            if weights_enum.__name__ != enum_name:
                continue
            try:
                # Enum class subscription looks a member up by its name.
                return weights_enum[member_name]
            except KeyError:
                break
    raise ValueError(f"{name} is not a valid WeightsEnum")


def list_models() -> list[str]:
    """Return the names of all registered models.

    .. versionadded:: 0.4

    Returns:
        A new list of the registered model names, in registration order.
    """
    # Iterating a dict yields its keys in insertion order.
    return list(_model)

© Copyright 2021, Microsoft Corporation. Revision 1a2820e2.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
v0.5.2
v0.5.1
v0.5.0
v0.4.1
v0.4.0
v0.3.1
v0.3.0
v0.2.1
v0.2.0
v0.1.1
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources