Source code for torchgeo.models.fcn
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Simple fully convolutional neural network (FCN) implementations."""
import torch.nn as nn
from torch import Tensor
from torch.nn.modules import Module
[docs]class FCN(Module):
"""A simple 5 layer FCN with leaky relus and 'same' padding."""
[docs] def __init__(self, in_channels: int, classes: int, num_filters: int = 64) -> None:
"""Initializes the 5 layer FCN model.
Args:
in_channels: Number of input channels that the model will expect
classes: Number of filters in the final layer
num_filters: Number of filters in each convolutional layer
"""
super().__init__()
conv1 = nn.modules.Conv2d(
in_channels, num_filters, kernel_size=3, stride=1, padding=1
)
conv2 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
conv3 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
conv4 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
conv5 = nn.modules.Conv2d(
num_filters, num_filters, kernel_size=3, stride=1, padding=1
)
self.backbone = nn.modules.Sequential(
conv1,
nn.modules.LeakyReLU(inplace=True),
conv2,
nn.modules.LeakyReLU(inplace=True),
conv3,
nn.modules.LeakyReLU(inplace=True),
conv4,
nn.modules.LeakyReLU(inplace=True),
conv5,
nn.modules.LeakyReLU(inplace=True),
)
self.last = nn.modules.Conv2d(
num_filters, classes, kernel_size=1, stride=1, padding=0
)
[docs] def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the model."""
x = self.backbone(x)
x = self.last(x)
return x