mirror of https://github.com/alibaba/EasyCV.git
90 lines
3.2 KiB
Python
90 lines
3.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import torch.nn as nn
|
|
|
|
from .conv_module import build_conv_layer
|
|
from .norm import build_norm_layer
|
|
|
|
|
|
class ResLayer(nn.Sequential):
    """ResLayer to build ResNet style backbone.

    The layer is an ``nn.Sequential`` of ``num_blocks`` instances of
    ``block``. Only the first block receives ``stride`` and the optional
    ``downsample`` shortcut; the remaining blocks are built with stride 1
    and ``out_channels`` as their input width.

    Args:
        block (nn.Module): Residual block class used to build ResLayer. It is
            called with the keyword arguments ``in_channels``,
            ``out_channels``, ``expansion``, ``stride``, ``conv_cfg``,
            ``norm_cfg`` (plus ``downsample`` for the first block) and any
            extra ``**kwargs``.
        num_blocks (int): Number of blocks.
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int, optional): The expansion for BasicBlock/Bottleneck.
            If not specified, it is obtained from ``block.expansion``. If the
            block has no attribute "expansion" (or it is None), 4 is used.
            Default: None.
        stride (int): stride of the first block. Default: 1.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict, optional): dictionary to construct and config conv
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        **kwargs: Extra keyword arguments forwarded to every ``block`` call.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 in_channels,
                 out_channels,
                 expansion=None,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        # Resolve the expansion instead of hard-coding it: an explicit
        # argument wins, then the block's own ``expansion`` attribute, and
        # finally 4 (the value this layer always used before).
        if expansion is None:
            expansion = getattr(block, 'expansion', None)
        if expansion is None:
            expansion = 4

        self.block = block
        self.expansion = expansion

        # A projection shortcut is needed whenever the first block changes
        # the spatial size (stride != 1) or the channel count.
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = []
            conv_stride = stride
            if avg_down and stride != 1:
                # Downsample with average pooling and keep the 1x1 conv at
                # stride 1 instead of using a strided conv.
                conv_stride = 1
                downsample.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            downsample.extend([
                build_conv_layer(
                    conv_cfg,
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                # build_norm_layer returns a sequence; index 1 is used as the
                # layer here (presumably a (name, layer) pair - confirm).
                build_norm_layer(norm_cfg, out_channels)[1]
            ])
            downsample = nn.Sequential(*downsample)

        # First block: carries the stride and the optional shortcut.
        layers = [
            block(
                in_channels=in_channels,
                out_channels=out_channels,
                expansion=self.expansion,
                stride=stride,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                **kwargs)
        ]

        # Remaining blocks: input width equals the output width, stride 1,
        # and no ``downsample`` argument is passed.
        in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    expansion=self.expansion,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))

        super(ResLayer, self).__init__(*layers)
|