# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import OptConfigType

class BasicBlock(BaseModule):
    """Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    Args:
        in_channels (int): Input channels.
        channels (int): Output channels.
        stride (int): Stride of the first convolution. Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer applied
            at the end of the block. Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 1

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv1 = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # The second conv has no activation of its own; the optional output
        # activation runs after the residual addition in ``forward``.
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.downsample = downsample
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)

        # Project the identity branch when a downsample module is provided.
        if self.downsample:
            residual = self.downsample(x)

        out += residual

        if hasattr(self, 'act'):
            out = self.act(out)

        return out
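
# Usage sketch (illustrative; not part of the original module). With the
# default stride=1 and no ``downsample``, BasicBlock keeps both channel count
# and spatial size, so the residual addition needs no projection. The tensor
# sizes below are assumptions chosen only for this example:
#
#   >>> import torch
#   >>> block = BasicBlock(in_channels=64, channels=64)
#   >>> block(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 64, 56, 56])
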
class Bottleneck(BaseModule):
    """Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.

    Args:
        in_channels (int): Input channels.
        channels (int): Base channels of the block; the output has
            ``channels * expansion`` channels.
        stride (int): Stride of the middle (3x3) convolution. Default: 1.
        downsample (nn.Module, optional): Downsample operation on identity.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU', inplace=True).
        act_cfg_out (dict, optional): Config dict for activation layer applied
            at the end of the block. Default: None.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    expansion = 2

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 act_cfg_out: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        # 1x1 conv -> 3x3 (strided) conv -> 1x1 expansion to
        # ``channels * expansion``.
        self.conv1 = ConvModule(
            in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv2 = ConvModule(
            channels,
            channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv3 = ConvModule(
            channels,
            channels * self.expansion,
            1,
            norm_cfg=norm_cfg,
            act_cfg=None)
        if act_cfg_out:
            self.act = MODELS.build(act_cfg_out)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        # Project the identity branch when a downsample module is provided.
        if self.downsample:
            residual = self.downsample(x)

        out += residual

        if hasattr(self, 'act'):
            out = self.act(out)

        return out
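
# Usage sketch (illustrative; not part of the original module). Bottleneck
# widens its output to ``channels * expansion`` (2x here), so when the identity
# branch has a different width it must be projected through ``downsample``
# before the addition. The 1x1 ConvModule projection and tensor sizes below are
# assumptions chosen only for this example:
#
#   >>> import torch
#   >>> proj = ConvModule(64, 128, 1, norm_cfg=dict(type='BN'), act_cfg=None)
#   >>> block = Bottleneck(in_channels=64, channels=64, downsample=proj)
#   >>> block(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 128, 56, 56])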