from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence

import torch
from torch import nn, Tensor
from torch.nn import functional as F

# This import style keeps compatibility across different torchvision versions
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    try:
        from torch.hub import load_state_dict_from_url
    except ImportError:
        from torch.utils.model_zoo import load_url as load_state_dict_from_url

from fastreid.layers import get_norm
from .build import BACKBONE_REGISTRY
from .mobilenet import _make_divisible

# https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py

model_urls = {
    "Large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
    "Small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
}


def conv_1x1_bn(inp, oup, bn_norm):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        get_norm(bn_norm, oup),
        nn.ReLU6(inplace=True)
    )


class ConvBNActivation(nn.Sequential):
    def __init__(
            self,
            in_planes: int,
            out_planes: int,
            kernel_size: int = 3,
            stride: int = 1,
            groups: int = 1,
            bn_norm=None,
            activation_layer: Optional[Callable[..., nn.Module]] = None,
            dilation: int = 1,
    ) -> None:
        padding = (kernel_size - 1) // 2 * dilation
        if activation_layer is None:
            activation_layer = nn.ReLU6
        super(ConvBNActivation, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
                      bias=False),
            get_norm(bn_norm, out_planes),
            activation_layer(inplace=True)
        )
        self.out_channels = out_planes


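# Squeeze-and-Excitation block (channel attention), kept in line with the
# torchvision MobileNetV3 reference. Rough shape trace for an input (N, C, H, W):
#   adaptive_avg_pool2d -> (N, C, 1, 1)
#   fc1 (1x1 conv)      -> (N, squeeze_channels, 1, 1), roughly C // squeeze_factor
#   fc2 (1x1 conv)      -> (N, C, 1, 1), squashed to [0, 1] by hard-sigmoid
# The resulting per-channel gate is broadcast-multiplied back onto the input.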
class SqueezeExcitation(nn.Module):
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)

    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
        scale = F.adaptive_avg_pool2d(input, 1)
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        return F.hardsigmoid(scale, inplace=inplace)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input, True)
        return scale * input


class InvertedResidualConfig:
    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
                 activation: str, stride: int, dilation: int, width_mult: float):
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        return _make_divisible(channels * width_mult, 8)


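# MobileNetV3 bottleneck: 1x1 expansion -> depthwise conv -> optional SE ->
# 1x1 linear projection, with a residual connection only when stride == 1 and
# the input/output channel counts match (use_res_connect below).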
class InvertedResidual(nn.Module):
    def __init__(self, cnf: InvertedResidualConfig, bn_norm,
                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU

        # expand
        if cnf.expanded_channels != cnf.input_channels:
            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
                                           bn_norm=bn_norm, activation_layer=activation_layer))

        # depthwise
        stride = 1 if cnf.dilation > 1 else cnf.stride
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
                                       stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                       bn_norm=bn_norm, activation_layer=activation_layer))
        if cnf.use_se:
            layers.append(se_layer(cnf.expanded_channels))

        # project
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, bn_norm=bn_norm,
                                       activation_layer=nn.Identity))

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result += input
        return result


class MobileNetV3(nn.Module):
    def __init__(
            self,
            bn_norm,
            inverted_residual_setting: List[InvertedResidualConfig],
            last_channel: int,
            block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        MobileNet V3 main class
        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
        """
        super().__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (isinstance(inverted_residual_setting, Sequence) and
                  all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, bn_norm=bn_norm,
                                       activation_layer=nn.Hardswish))

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, bn_norm))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                       bn_norm=bn_norm, activation_layer=nn.Hardswish))

        self.features = nn.Sequential(*layers)
        self.conv = conv_1x1_bn(lastconv_output_channels, last_channel, bn_norm)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.conv(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


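# Each bneck_conf(...) row below fills InvertedResidualConfig positionally:
# (input_channels, kernel, expanded_channels, out_channels, use_se, activation,
#  stride, dilation), with width_mult bound via functools.partial. "RE"/"HS"
# select ReLU or Hardswish, and the C1..C5 markers are carried over from the
# torchvision reference to flag the resolution-reduction points.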
def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):
    # non-public config parameters
    reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
    dilation = 2 if params.pop('_dilated', False) else 1
    width_mult = params.pop('_width_mult', 1.0)

    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    if arch == "Large":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
            bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
            bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
            bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1280 // reduce_divider)  # C5
    elif arch == "Small":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
            bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1024 // reduce_divider)  # C5
    else:
        raise ValueError("Unsupported model type {}".format(arch))

    return inverted_residual_setting, last_channel


def _mobilenet_v3_model(
        bn_norm,
        depth: str,
        pretrained: bool,
        pretrain_path: str,
        **kwargs: Any
):
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(depth, kwargs)
    model = MobileNetV3(bn_norm, inverted_residual_setting, last_channel, **kwargs)
    if pretrained:
        if pretrain_path:
            state_dict = torch.load(pretrain_path)
        else:
            if model_urls.get(depth, None) is None:
                raise ValueError("No checkpoint is available for model type {}".format(depth))
            state_dict = load_state_dict_from_url(model_urls[depth], progress=True)
        model.load_state_dict(state_dict, strict=False)
    return model


@BACKBONE_REGISTRY.register()
def build_mobilenetv3_backbone(cfg):
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    bn_norm = cfg.MODEL.BACKBONE.NORM
    depth = cfg.MODEL.BACKBONE.DEPTH

    model = _mobilenet_v3_model(bn_norm, depth, pretrain, pretrain_path)

    return model
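

# Minimal smoke-test sketch (not part of the original module). It fakes only the
# cfg fields that build_mobilenetv3_backbone reads rather than using a real
# fastreid CfgNode, assumes get_norm accepts the "BN" norm type, and has to be
# run as a module (python -m ...) so that the relative imports above resolve.
if __name__ == "__main__":
    from types import SimpleNamespace

    backbone_cfg = SimpleNamespace(
        PRETRAIN=False,       # skip loading ImageNet weights
        PRETRAIN_PATH="",
        NORM="BN",
        DEPTH="Small",        # or "Large"
    )
    cfg = SimpleNamespace(MODEL=SimpleNamespace(BACKBONE=backbone_cfg))

    model = build_mobilenetv3_backbone(cfg)
    model.eval()  # use BatchNorm running stats for the single-image forward pass
    feat = model(torch.randn(1, 3, 256, 128))  # a typical ReID input resolution
    print(feat.shape)  # expected torch.Size([1, 1024, 8, 4]) for the "Small" arch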