mirror of https://github.com/alibaba/EasyCV.git
1003 lines
35 KiB
Python
1003 lines
35 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/models/backbones/litehrnet.py
|
|
|
|
import mmcv
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint as cp
|
|
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
|
build_conv_layer, build_norm_layer, constant_init,
|
|
normal_init)
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
from easycv.models.registry import BACKBONES
|
|
|
|
|
|
def channel_shuffle(x, groups):
    """Interleave channels across ``groups`` (ShuffleNet channel shuffle).

    The channel axis is viewed as ``(groups, channels_per_group)``, those
    two axes are swapped, and the result is flattened back. Information can
    then flow between the groups of subsequent group convolutions.

    Args:
        x (Tensor): Input of shape (N, C, H, W).
        groups (int): Number of groups to split the channel dimension into;
            must divide C.

    Returns:
        Tensor: Tensor of the same shape as ``x`` with channels shuffled.
    """
    n, c, h, w = x.size()
    assert (c % groups == 0), ('num_channels should be '
                               'divisible by groups')
    per_group = c // groups

    grouped = x.view(n, groups, per_group, h, w)
    # swap the group and per-group axes, then make memory contiguous so the
    # following flattening view is valid
    swapped = grouped.transpose(1, 2).contiguous()
    return swapped.view(n, -1, h, w)
|
|
|
|
|
|
class SpatialWeighting(nn.Module):
    """Spatial weighting module (squeeze-and-excitation style gate).

    The input is globally average-pooled to 1x1 and passed through two 1x1
    ``ConvModule`` layers: the first reduces channels by ``ratio``, the
    second restores them. The result multiplies the input element-wise.

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio. The hidden width is
            ``int(channels / ratio)``.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict | tuple[dict]): Config dict(s) for the two activation
            layers. A single dict is used for both convs.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        act_cfgs = (act_cfg, act_cfg) if isinstance(act_cfg, dict) else act_cfg
        assert len(act_cfgs) == 2
        assert mmcv.is_tuple_of(act_cfgs, dict)
        hidden_channels = int(channels / ratio)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=hidden_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfgs[0])
        self.conv2 = ConvModule(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfgs[1])

    def forward(self, x):
        squeezed = self.global_avgpool(x)
        gate = self.conv2(self.conv1(squeezed))
        return x * gate
|
|
|
|
|
|
class CrossResolutionWeighting(nn.Module):
    """Cross-resolution channel weighting module.

    Every branch is average-pooled to the spatial size of the last branch.
    The pooled maps are concatenated along channels and passed through two
    1x1 ``ConvModule`` layers. The result is split back per branch,
    nearest-upsampled to each branch's size, and multiplied into that
    branch.

    Args:
        channels (list[int]): The channels of each branch.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict | tuple[dict]): Config dict(s) for the two activation
            layers. A single dict is used for both convs.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        act_cfgs = (act_cfg, act_cfg) if isinstance(act_cfg, dict) else act_cfg
        assert len(act_cfgs) == 2
        assert mmcv.is_tuple_of(act_cfgs, dict)
        self.channels = channels
        total_channel = sum(channels)
        hidden_channels = int(total_channel / ratio)

        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=hidden_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfgs[0])
        self.conv2 = ConvModule(
            in_channels=hidden_channels,
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfgs[1])

    def forward(self, x):
        # the last branch is the smallest; everything is pooled to its size
        pool_size = x[-1].size()[-2:]
        pooled = [F.adaptive_avg_pool2d(feat, pool_size) for feat in x[:-1]]
        pooled.append(x[-1])

        weights = self.conv2(self.conv1(torch.cat(pooled, dim=1)))
        per_branch = torch.split(weights, self.channels, dim=1)

        weighted = []
        for feat, weight in zip(x, per_branch):
            upsampled = F.interpolate(
                weight, size=feat.size()[-2:], mode='nearest')
            weighted.append(feat * upsampled)
        return weighted
|
|
|
|
|
|
class ConditionalChannelWeighting(nn.Module):
    """Conditional channel weighting block.

    Each input branch is split in half along channels.
    - The first halves pass through unchanged.
    - The second halves go through cross-resolution weighting, a per-branch
      depthwise 3x3 conv, and per-branch spatial weighting.
    The halves are then concatenated back per branch and channel-shuffled.

    Args:
        in_channels (list[int]): The input channels of each branch.
        stride (int): Stride of the 3x3 convolution layer.
        reduce_ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels,
            ratio=reduce_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        self.depthwise_convs = nn.ModuleList([
            ConvModule(
                channel,
                channel,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=channel,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None) for channel in branch_channels
        ])

        self.spatial_weighting = nn.ModuleList([
            SpatialWeighting(channels=channel, ratio=4)
            for channel in branch_channels
        ])

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): One feature map per branch.

        Returns:
            list[Tensor]: One feature map per branch.
        """

        def _weight_branches(*halves):
            # Takes tensors as separate positional args and returns a tuple
            # so that torch.utils.checkpoint can track every tensor.
            halves = self.cross_resolution_weighting(list(halves))
            halves = [dw(s) for s, dw in zip(halves, self.depthwise_convs)]
            return tuple(
                sw(s) for s, sw in zip(halves, self.spatial_weighting))

        splits = [s.chunk(2, dim=1) for s in x]
        x1 = [s[0] for s in splits]
        x2 = [s[1] for s in splits]

        # `x` is a list, so requires_grad must be checked per tensor (the
        # list itself has no such attribute).
        if self.with_cp and any(s.requires_grad for s in x):
            x2 = cp.checkpoint(_weight_branches, *x2)
        else:
            x2 = _weight_branches(*x2)

        # concat + shuffle run outside the checkpoint so the returned views
        # are ordinary autograd views (callers update them in place).
        out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
        return [channel_shuffle(s, 2) for s in out]
|
|
|
|
|
|
class Stem(nn.Module):
    """Stem network block.

    Steps:
    1. A stride-2 3x3 conv maps the input to ``stem_channels``.
    2. The result is split in half along channels.
    3. One half goes through a depthwise stride-2 conv plus a 1x1 conv.
    4. The other half goes through expand (1x1), depthwise stride-2 and
       linear (1x1) convs.
    5. Both are concatenated to ``out_channels`` and channel-shuffled.

    Args:
        in_channels (int): The input channels of the block.
        stem_channels (int): Output channels of the stem layer.
        out_channels (int): The output channels of the block.
        expand_ratio (int): Multiplier for the hidden width of the expand
            path; hidden width is ``int(round(stem_channels * expand_ratio))``.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'))

        hidden = int(round(stem_channels * expand_ratio))
        half = stem_channels // 2
        # Width produced by the expand path; branch1 produces the remainder
        # so that the concatenation has exactly out_channels channels.
        expand_out = half if stem_channels == self.out_channels \
            else stem_channels
        branch1_out = self.out_channels - expand_out

        self.branch1 = nn.Sequential(
            ConvModule(
                half,
                half,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=half,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                half,
                branch1_out,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )

        self.expand_conv = ConvModule(
            half,
            hidden,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            hidden,
            hidden,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=hidden,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.linear_conv = ConvModule(
            hidden,
            expand_out,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):

        def _inner_forward(inp):
            feat = self.conv1(inp)
            left, right = feat.chunk(2, dim=1)
            right = self.linear_conv(
                self.depthwise_conv(self.expand_conv(right)))
            merged = torch.cat((self.branch1(left), right), dim=1)
            return channel_shuffle(merged, 2)

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
|
|
|
|
|
|
class IterativeHead(nn.Module):
    """Extra iterative head for feature learning.

    Branches are processed from the last (lowest resolution) to the first.
    Each branch is summed with the bilinearly upsampled projection of the
    previously processed branch, then projected with a depthwise-separable
    conv. The projection maps to the next branch's width (the last one
    keeps its width). Outputs are returned in the original branch order.

    Args:
        in_channels (list[int]): The input channels of each branch.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
        super().__init__()
        num_branches = len(in_channels)
        self.in_channels = in_channels[::-1]

        projects = []
        for idx, channels in enumerate(self.in_channels):
            is_last = idx == num_branches - 1
            target = channels if is_last else self.in_channels[idx + 1]
            projects.append(
                DepthwiseSeparableConvModule(
                    in_channels=channels,
                    out_channels=target,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    dw_act_cfg=None,
                    pw_act_cfg=dict(type='ReLU')))
        self.projects = nn.ModuleList(projects)

    def forward(self, x):
        outputs = []
        prev = None
        for idx, feat in enumerate(x[::-1]):
            if prev is not None:
                feat = feat + F.interpolate(
                    prev,
                    size=feat.size()[-2:],
                    mode='bilinear',
                    align_corners=True)
            prev = self.projects[idx](feat)
            outputs.append(prev)
        return outputs[::-1]
|
|
|
|
|
|
class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    With ``stride == 1`` the input is split in half. One half is kept and
    the other goes through ``branch2``. With ``stride > 1`` the full input
    goes through both ``branch1`` and ``branch2``. In both cases the two
    results are concatenated and channel-shuffled.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        half = out_channels // 2
        if self.stride == 1:
            assert in_channels == half * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({half * 2}) '
                'when stride is 1')

        if in_channels != half * 2:
            assert self.stride != 1, (
                f'stride ({self.stride}) should not equal 1 when '
                f'in_channels != branch_features * 2')

        downsample = self.stride > 1
        if downsample:
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    half,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        self.branch2 = nn.Sequential(
            ConvModule(
                in_channels if downsample else half,
                half,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                half,
                half,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=half,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                half,
                half,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):

        def _inner_forward(inp):
            if self.stride > 1:
                merged = torch.cat((self.branch1(inp), self.branch2(inp)),
                                   dim=1)
            else:
                kept, processed = inp.chunk(2, dim=1)
                merged = torch.cat((kept, self.branch2(processed)), dim=1)
            return channel_shuffle(merged, 2)

        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
|
|
|
|
|
|
class LiteHRModule(nn.Module):
    """High-Resolution Module for LiteHRNet.

    It contains conditional channel weighting blocks ('LITE') or
    shuffle blocks ('NAIVE'), optionally followed by cross-branch fusion.

    Args:
        num_branches (int): Number of branches in the module.
        num_blocks (int): Number of blocks in the module.
        in_channels (list(int)): Number of input channels of each branch.
        reduce_ratio (int): Channel reduction ratio.
        module_type (str): 'LITE' or 'NAIVE'
        multiscale_output (bool): Whether to output multi-scale features.
        with_fuse (bool): Whether to use fuse layers.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(
        self,
        num_branches,
        num_blocks,
        in_channels,
        reduce_ratio,
        module_type,
        multiscale_output=False,
        with_fuse=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        with_cp=False,
    ):
        super().__init__()
        self._check_branches(num_branches, in_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.module_type = module_type
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp

        if self.module_type.upper() == 'LITE':
            self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
        elif self.module_type.upper() == 'NAIVE':
            self.layers = self._make_naive_branches(num_branches, num_blocks)
        else:
            raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU()

    def _check_branches(self, num_branches, in_channels):
        """Check input to avoid ValueError."""
        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
        """Make ``num_blocks`` channel weighting blocks over all branches."""
        layers = []
        for _ in range(num_blocks):
            layers.append(
                ConditionalChannelWeighting(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_one_branch(self, branch_index, num_blocks, stride=1):
        """Make one branch of ``num_blocks`` ShuffleUnits.

        Only the first unit uses ``stride``; the rest use stride 1.
        """
        layers = []
        layers.append(
            ShuffleUnit(
                self.in_channels[branch_index],
                self.in_channels[branch_index],
                stride=stride,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_cp=self.with_cp))
        for _ in range(1, num_blocks):
            layers.append(
                ShuffleUnit(
                    self.in_channels[branch_index],
                    self.in_channels[branch_index],
                    stride=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_naive_branches(self, num_branches, num_blocks):
        """Make one ShuffleUnit branch per input branch."""
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, num_blocks))

        return nn.ModuleList(branches)

    def _make_dw_downsample(self, in_channels, out_channels, with_relu):
        """Build depthwise 3x3 stride-2 conv + BN, then 1x1 conv + BN
        (+ inplace ReLU if ``with_relu``)."""
        layers = [
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                in_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=in_channels,
                bias=False),
            build_norm_layer(self.norm_cfg, in_channels)[1],
            build_conv_layer(
                self.conv_cfg,
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            build_norm_layer(self.norm_cfg, out_channels)[1],
        ]
        if with_relu:
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def _make_fuse_layers(self):
        """Make fuse layers.

        ``fuse_layers[i][j]`` maps branch ``j`` to the resolution and width
        of branch ``i``:
        - j > i: 1x1 conv + norm + nearest upsample by ``2**(j - i)``.
        - j == i: None.
        - j < i: ``i - j`` stride-2 depthwise-separable downsamples; only
          the last one changes the width, and it has no ReLU.

        Returns None when there is a single branch.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        is_last = k == i - j - 1
                        conv_downsamples.append(
                            self._make_dw_downsample(
                                in_channels[j],
                                in_channels[i] if is_last else in_channels[j],
                                with_relu=not is_last))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): One feature map per branch.

        Returns:
            list[Tensor]: All fused branches if ``multiscale_output``,
            otherwise a single-element list.
        """
        if self.num_branches == 1:
            if self.module_type.upper() == 'LITE':
                # Weighting blocks consume and return a list of branches;
                # feeding the first block a bare tensor would make it
                # iterate over the batch dimension.
                return list(self.layers(x))
            return [self.layers[0](x[0])]

        if self.module_type.upper() == 'LITE':
            out = self.layers(x)
        else:  # 'NAIVE' (any other value is rejected in __init__)
            # work on a copy so the caller's list is not mutated
            out = list(x)
            for i, branch in enumerate(self.layers):
                out[i] = branch(out[i])

        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
                # NOTE(review): j starts at 0, so branch 0 is added twice.
                # For i == 0, `y` aliases out[0], so `+=` updates out[0] in
                # place and later rows read the updated tensor. This is kept
                # byte-for-byte because existing weights presumably depend
                # on it.
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if i == j:
                        y += out[j]
                    else:
                        y += self.fuse_layers[i][j](out[j])
                out_fuse.append(self.relu(y))
            out = out_fuse
        if not self.multiscale_output:
            out = [out[0]]
        return out
|
|
|
|
|
|
@BACKBONES.register_module()
class LiteHRNet(nn.Module):
    """Lite-HRNet backbone.

    `Lite-HRNet: A Lightweight High-Resolution Network
    <https://arxiv.org/abs/2104.06403>`__

    Code adapted from 'https://github.com/HRNet/Lite-HRNet/'
    'blob/hrnet/models/backbones/litehrnet.py'

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.

    Note:
        The import line in the example is the upstream mmpose path this file
        was adapted from.

    Example:
        >>> from mmpose.models import LiteHRNet
        >>> import torch
        >>> extra=dict(
        ...     stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        ...     num_stages=3,
        ...     stages_spec=dict(
        ...         num_modules=(2, 4, 2),
        ...         num_branches=(2, 3, 4),
        ...         num_blocks=(2, 2, 2),
        ...         module_type=('LITE', 'LITE', 'LITE'),
        ...         with_fuse=(True, True, True),
        ...         reduce_ratios=(8, 8, 8),
        ...         num_channels=(
        ...             (40, 80),
        ...             (40, 80, 160),
        ...             (40, 80, 160, 320),
        ...         )),
        ...     with_head=False)
        >>> self = LiteHRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 40, 8, 8)
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False):
        super().__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.stem = Stem(
            in_channels,
            stem_channels=self.extra['stem']['stem_channels'],
            out_channels=self.extra['stem']['out_channels'],
            expand_ratio=self.extra['stem']['expand_ratio'],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

        self.num_stages = self.extra['num_stages']
        self.stages_spec = self.extra['stages_spec']

        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            num_channels = list(self.stages_spec['num_channels'][i])
            setattr(
                self, f'transition{i}',
                self._make_transition_layer(num_channels_last, num_channels))

            stage, num_channels_last = self._make_stage(
                self.stages_spec, i, num_channels, multiscale_output=True)
            setattr(self, f'stage{i}', stage)

        self.with_head = self.extra['with_head']
        if self.with_head:
            self.head_layer = IterativeHead(
                in_channels=num_channels_last,
                norm_cfg=self.norm_cfg,
            )

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer.

        Per branch:
        - Existing branch, same width: None.
        - Existing branch, width change: depthwise 3x3 + 1x1 conv.
        - New branch: stride-2 depthwise-separable downsamples from the last
          previous branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self,
                    stages_spec,
                    stage_index,
                    in_channels,
                    multiscale_output=True):
        """Build the LiteHRModules of one stage.

        Returns:
            tuple: (nn.Sequential of modules, output channels of the last
            module).
        """
        num_modules = stages_spec['num_modules'][stage_index]
        num_branches = stages_spec['num_branches'][stage_index]
        num_blocks = stages_spec['num_blocks'][stage_index]
        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
        with_fuse = stages_spec['with_fuse'][stage_index]
        module_type = stages_spec['module_type'][stage_index]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self):
        """Initialize the weights in backbone.

        Every ``nn.Conv2d`` is initialized with mmcv ``normal_init`` using
        ``std=0.001``. Every ``_BatchNorm`` / ``nn.GroupNorm`` is
        initialized with mmcv ``constant_init(m, 1)``. Takes no arguments.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor]: A single-element list holding the first
            (highest-resolution) branch of the last stage, or of the
            iterative head when ``with_head`` is set.
        """
        x = self.stem(x)

        y_list = [x]
        for i in range(self.num_stages):
            x_list = []
            transition = getattr(self, f'transition{i}')
            for j in range(self.stages_spec['num_branches'][i]):
                if transition[j] is not None:
                    # new branches are derived from the last existing branch
                    if j >= len(y_list):
                        x_list.append(transition[j](y_list[-1]))
                    else:
                        x_list.append(transition[j](y_list[j]))
                else:
                    x_list.append(y_list[j])
            y_list = getattr(self, f'stage{i}')(x_list)

        x = y_list
        if self.with_head:
            x = self.head_layer(x)

        return [x[0]]

    def train(self, mode=True):
        """Convert the model into training mode.

        When ``norm_eval`` is set, ``_BatchNorm`` layers stay in eval mode.
        Returns ``self`` to honour the ``nn.Module.train`` contract.
        """
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
|