mirror of https://github.com/JDAI-CV/fast-reid.git
feat($modeling/backbones): add new backbones
add osnet, resnext and resnest backbone support
pull/43/head
parent
b098b194ba
commit
e3ae03cc58
@@ -7,5 +7,6 @@
from .build import build_backbone, BACKBONE_REGISTRY

from .resnet import build_resnet_backbone
# from .osnet import *
# from .attention import ResidualAttentionNet_56
from .osnet import build_osnet_backbone
from .resnest import build_resnest_backbone
from .resnext import build_resnext_backbone
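The three new imports register the OSNet, ResNeSt and ResNeXt builders with BACKBONE_REGISTRY so they can be selected by name from the config. build.py itself is not part of this commit, so the following is only a sketch of the registry dispatch it is assumed to perform (the function name here is illustrative):

def build_backbone_sketch(cfg):
    # Look up the builder registered under cfg.MODEL.BACKBONE.NAME,
    # e.g. "build_osnet_backbone", and call it with the config.
    name = cfg.MODEL.BACKBONE.NAME
    return BACKBONE_REGISTRY.get(name)(cfg)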
@@ -0,0 +1,493 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

# based on:
# https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/models/osnet.py

import torch
from torch import nn
from torch.nn import functional as F
from .build import BACKBONE_REGISTRY

model_urls = {
    'osnet_x1_0':
        'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
    'osnet_x0_75':
        'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
    'osnet_x0_5':
        'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
    'osnet_x0_25':
        'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
    'osnet_ibn_x1_0':
        'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}


##########
# Basic layers
##########
class ConvLayer(nn.Module):
    """Convolution layer (conv + bn + relu)."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            groups=1,
            IN=False
    ):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups
        )
        if IN:
            self.bn = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Conv1x1(nn.Module):
    """1x1 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=0,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Conv1x1Linear(nn.Module):
    """1x1 convolution + bn (w/o non-linearity)."""

    def __init__(self, in_channels, out_channels, stride=1):
        super(Conv1x1Linear, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1, stride=stride, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Conv3x3(nn.Module):
    """3x3 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.
    1x1 (linear) + dw 3x3 (nonlinear).
    """

    def __init__(self, in_channels, out_channels):
        super(LightConv3x3, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 1, stride=1, padding=0, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on input tensor."""

    def __init__(
            self,
            in_channels,
            num_gates=None,
            return_gates=False,
            gate_activation='sigmoid',
            reduction=16,
            layer_norm=False
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels,
            in_channels // reduction,
            kernel_size=1,
            bias=True,
            padding=0
        )
        self.norm1 = None
        if layer_norm:
            self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(
            in_channels // reduction,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def forward(self, x):
        input = x
        x = self.global_avgpool(x)
        x = self.fc1(x)
        if self.norm1 is not None:
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        return input * x


class OSBlock(nn.Module):
    """Omni-scale feature learning block."""

    def __init__(
            self,
            in_channels,
            out_channels,
            IN=False,
            bottleneck_reduction=4,
            **kwargs
    ):
        super(OSBlock, self).__init__()
        mid_channels = out_channels // bottleneck_reduction
        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2a = LightConv3x3(mid_channels, mid_channels)
        self.conv2b = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2c = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2d = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN = None
        if IN:
            self.IN = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2a = self.conv2a(x1)
        x2b = self.conv2b(x1)
        x2c = self.conv2c(x1)
        x2d = self.conv2d(x1)
        x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        if self.IN is not None:
            out = self.IN(out)
        return F.relu(out)


##########
# Network architecture
##########
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. arXiv preprint, 2019.
    """

    def __init__(
            self,
            blocks,
            layers,
            channels,
            IN=False,
            **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1

        # convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0],
            layers[0],
            channels[0],
            channels[1],
            reduce_spatial_size=True,
            IN=IN
        )
        self.conv3 = self._make_layer(
            blocks[1],
            layers[1],
            channels[1],
            channels[2],
            reduce_spatial_size=True
        )
        self.conv4 = self._make_layer(
            blocks[2],
            layers[2],
            channels[2],
            channels[3],
            reduce_spatial_size=False
        )
        self.conv5 = Conv1x1(channels[3], channels[3])

        self._init_params()

    def _make_layer(
            self,
            block,
            layer,
            in_channels,
            out_channels,
            reduce_spatial_size,
            IN=False
    ):
        layers = []

        layers.append(block(in_channels, out_channels, IN=IN))
        for i in range(1, layer):
            layers.append(block(out_channels, out_channels, IN=IN))

        if reduce_spatial_size:
            layers.append(
                nn.Sequential(
                    Conv1x1(out_channels, out_channels),
                    nn.AvgPool2d(2, stride=2)
                )
            )

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x


def init_pretrained_weights(model, key=''):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import errno
    import gdown
    from collections import OrderedDict
    import warnings
    import logging

    logger = logging.getLogger(__name__)

    def _get_torch_home():
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise
    filename = key + '_imagenet.pth'
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        gdown.download(model_urls[key], cached_file, quiet=False)

    state_dict = torch.load(cached_file)
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[7:]  # discard module.

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(cached_file)
        )
    else:
        logger.info(
            'Successfully loaded imagenet pretrained weights from "{}"'.
            format(cached_file)
        )
        if len(discarded_layers) > 0:
            logger.info(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )


@BACKBONE_REGISTRY.register()
def build_osnet_backbone(cfg):
    """
    Create an OSNet instance from config.
    Returns:
        OSNet: a :class:`OSNet` instance
    """

    # fmt: off
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN

    num_blocks_per_stage = [2, 2, 2]
    num_channels_per_stage = [64, 256, 384, 512]
    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, with_ibn)
    pretrain_key = 'osnet_ibn_x1_0' if with_ibn else 'osnet_x1_0'
    if pretrain:
        init_pretrained_weights(model, pretrain_key)
    return model
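A usage sketch (not part of the commit) for the file above: the OSNet x1.0 configuration that build_osnet_backbone constructs, run on a typical 256x128 re-id crop. The expected output shape follows from the stride-2 stem, the max-pool, and the two stages built with reduce_spatial_size=True.

# Sketch: build OSNet x1.0 directly from the classes defined above.
net = OSNet([OSBlock, OSBlock, OSBlock], layers=[2, 2, 2], channels=[64, 256, 384, 512])
net.eval()
with torch.no_grad():
    feat = net(torch.randn(1, 3, 256, 128))
# 1/16 input resolution, 512 output channels
print(feat.shape)  # expected: torch.Size([1, 512, 16, 8])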
@@ -0,0 +1,401 @@
# encoding: utf-8
# based on:
# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
"""ResNeSt models"""

import logging
import math

import torch
from torch import nn

from .build import BACKBONE_REGISTRY
from ...layers import SplAtConv2d, IBN, Non_local

_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'

_model_sha256 = {name: checksum for checksum, name in [
    ('528c19ca', 'resnest50'),
    ('22405ba7', 'resnest101'),
    ('75117900', 'resnest200'),
    ('0cc87c48', 'resnest269'),
]}


def short_hash(name):
    if name not in _model_sha256:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return _model_sha256[name][:8]


model_urls = {name: _url_format.format(name, short_hash(name)) for
              name in _model_sha256.keys()
              }


class Bottleneck(nn.Module):
    """ResNet Bottleneck
    """
    # pylint: disable=unused-argument
    expansion = 4

    def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        if with_ibn:
            self.bn1 = IBN(group_width)
        else:
            self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if radix > 1:
            self.conv2 = SplAtConv2d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            from rfconv import RFConv2d
            self.conv2 = RFConv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv2d(
            group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes * 4)

        if last_gamma:
            from torch.nn.init import zeros_
            zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 1:
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNest(nn.Module):
    """ResNet Variants ResNest
    Parameters
    ----------
    block : Block
        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
    layers : list of int
        Numbers of layers in each block
    classes : int, default 1000
        Number of classification classes.
    dilated : bool, default False
        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
        typically used in Semantic Segmentation.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        for Synchronized Cross-GPU BatchNormalization).
    Reference:
        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """

    # pylint: disable=unused-variable
    def __init__(self, last_stride, with_ibn, with_nl, block, layers, non_layers, radix=1, groups=1,
                 bottleneck_width=64,
                 dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, norm_layer=nn.BatchNorm2d):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first

        super().__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
        if rectified_conv:
            from rfconv import RFConv2d
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn, norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1, with_ibn=with_ibn,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, with_ibn=with_ibn,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, with_ibn=with_ibn,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_layer):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if with_nl:
            self._build_nonlocal(layers, non_layers)
        else:
            self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []

    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                if dilation == 1:
                    down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
                                                    ceil_mode=True, count_include_pad=False))
                else:
                    down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
                                                    ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if planes == 512:
            with_ibn = False
        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, with_ibn, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, with_ibn, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, with_ibn,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=dilation, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))

        return nn.Sequential(*layers)

    def _build_nonlocal(self, layers, non_layers):
        self.NL_1 = nn.ModuleList(
            [Non_local(256) for _ in range(non_layers[0])])
        self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
        self.NL_2 = nn.ModuleList(
            [Non_local(512) for _ in range(non_layers[1])])
        self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
        self.NL_3 = nn.ModuleList(
            [Non_local(1024) for _ in range(non_layers[2])])
        self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
        self.NL_4 = nn.ModuleList(
            [Non_local(2048) for _ in range(non_layers[3])])
        self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        NL1_counter = 0
        if len(self.NL_1_idx) == 0:
            self.NL_1_idx = [-1]
        for i in range(len(self.layer1)):
            x = self.layer1[i](x)
            if i == self.NL_1_idx[NL1_counter]:
                _, C, H, W = x.shape
                x = self.NL_1[NL1_counter](x)
                NL1_counter += 1
        # Layer 2
        NL2_counter = 0
        if len(self.NL_2_idx) == 0:
            self.NL_2_idx = [-1]
        for i in range(len(self.layer2)):
            x = self.layer2[i](x)
            if i == self.NL_2_idx[NL2_counter]:
                _, C, H, W = x.shape
                x = self.NL_2[NL2_counter](x)
                NL2_counter += 1
        # Layer 3
        NL3_counter = 0
        if len(self.NL_3_idx) == 0:
            self.NL_3_idx = [-1]
        for i in range(len(self.layer3)):
            x = self.layer3[i](x)
            if i == self.NL_3_idx[NL3_counter]:
                _, C, H, W = x.shape
                x = self.NL_3[NL3_counter](x)
                NL3_counter += 1
        # Layer 4
        NL4_counter = 0
        if len(self.NL_4_idx) == 0:
            self.NL_4_idx = [-1]
        for i in range(len(self.layer4)):
            x = self.layer4[i](x)
            if i == self.NL_4_idx[NL4_counter]:
                _, C, H, W = x.shape
                x = self.NL_4[NL4_counter](x)
                NL4_counter += 1

        return x


@BACKBONE_REGISTRY.register()
def build_resnest_backbone(cfg):
    """
    Create a ResNest instance from config.
    Returns:
        ResNest: a :class:`ResNest` instance.
    """

    # fmt: off
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
    with_se = cfg.MODEL.BACKBONE.WITH_SE
    with_nl = cfg.MODEL.BACKBONE.WITH_NL
    depth = cfg.MODEL.BACKBONE.DEPTH

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 200: [3, 24, 36, 3], 269: [3, 30, 48, 8]}[depth]
    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
    stem_width = {50: 32, 101: 64, 200: 64, 269: 64}[depth]
    model = ResNest(last_stride, with_ibn, with_nl, Bottleneck, num_blocks_per_stage, nl_layers_per_stage,
                    radix=2, groups=1, bottleneck_width=64,
                    deep_stem=True, stem_width=stem_width, avg_down=True,
                    avd=True, avd_first=False)
    if pretrain:
        # if not with_ibn:
        # original resnet
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls['resnest' + str(depth)], progress=True, check_hash=True)
        # else:
        #     raise KeyError('Not implementation ibn in resnest')
        #     # ibn resnet
        #     state_dict = torch.load(pretrain_path)['state_dict']
        #     # remove module in name
        #     new_state_dict = {}
        #     for k in state_dict:
        #         new_k = '.'.join(k.split('.')[1:])
        #         if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
        #             new_state_dict[new_k] = state_dict[k]
        #     state_dict = new_state_dict
        res = model.load_state_dict(state_dict, strict=False)
        logger = logging.getLogger(__name__)
        logger.info('missing keys is {}'.format(res.missing_keys))
        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
    return model
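A usage sketch (not part of the commit) for the file above: the ResNeSt-50 configuration that build_resnest_backbone assembles for cfg.MODEL.BACKBONE.DEPTH == 50, constructed directly. last_stride=1 is the usual fastreid default and is an assumption here; with IBN and non-local disabled, only the classes defined above plus the imported SplAtConv2d layer are exercised.

# Sketch: ResNeSt-50 as wired up by build_resnest_backbone (depth 50, stem_width 32).
model = ResNest(last_stride=1, with_ibn=False, with_nl=False,
                block=Bottleneck, layers=[3, 4, 6, 3], non_layers=[0, 2, 3, 0],
                radix=2, groups=1, bottleneck_width=64,
                deep_stem=True, stem_width=32, avg_down=True,
                avd=True, avd_first=False)
model.eval()
with torch.no_grad():
    feat = model(torch.randn(1, 3, 256, 128))
print(feat.shape)  # expected: torch.Size([1, 2048, 16, 8]) with last_stride=1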
@@ -0,0 +1,198 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

# based on:
# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py

import math
import logging
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch
from ...layers import IBN
from .build import BACKBONE_REGISTRY


class Bottleneck(nn.Module):
    """
    ResNeXt bottleneck type C
    """
    expansion = 4

    def __init__(self, inplanes, planes, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: conv stride. Replaces pooling layer.
        """
        super(Bottleneck, self).__init__()

        D = int(math.floor(planes * (baseWidth / 64)))
        C = cardinality
        self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
        if with_ibn:
            self.bn1 = IBN(D * C)
        else:
            self.bn1 = nn.BatchNorm2d(D * C)
        self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
        self.bn2 = nn.BatchNorm2d(D * C)
        self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNeXt(nn.Module):
    """
    ResNext optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, last_stride, with_ibn, block, layers, baseWidth=4, cardinality=32):
        """ Constructor
        Args:
            baseWidth: baseWidth for ResNeXt.
            cardinality: number of convolution groups.
            layers: config of layers, e.g., [3, 4, 6, 3]
            num_classes: number of classes
        """
        super(ResNeXt, self).__init__()

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn)

        self.random_init()

    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            block: block type used to construct ResNext
            planes: number of output channels (need to multiply by block.expansion)
            blocks: number of blocks to be built
            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        if planes == 512:
            with_ibn = False
        layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, 1, None))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def random_init(self):
        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
    """
    Create a ResNeXt instance from config.
    Returns:
        ResNeXt: a :class:`ResNeXt` instance.
    """

    # fmt: off
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
    with_se = cfg.MODEL.BACKBONE.WITH_SE
    with_nl = cfg.MODEL.BACKBONE.WITH_NL
    depth = cfg.MODEL.BACKBONE.DEPTH

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
    model = ResNeXt(last_stride, with_ibn, Bottleneck, num_blocks_per_stage)
    if pretrain:
        # if not with_ibn:
        # original resnet
        # state_dict = model_zoo.load_url(model_urls[depth])
        # else:
        # ibn resnet
        state_dict = torch.load(pretrain_path)['state_dict']
        # remove module in name
        new_state_dict = {}
        for k in state_dict:
            new_k = '.'.join(k.split('.')[1:])
            if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                new_state_dict[new_k] = state_dict[k]
        state_dict = new_state_dict
        res = model.load_state_dict(state_dict, strict=False)
        logger = logging.getLogger(__name__)
        logger.info('missing keys is {}'.format(res.missing_keys))
        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
    return model
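A usage sketch (not part of the commit) for the file above: the ResNeXt-50 (32x4d) configuration that build_resnext_backbone assembles for cfg.MODEL.BACKBONE.DEPTH == 50. As with the other examples, last_stride=1 is assumed to be the configured value.

# Sketch: ResNeXt-50 32x4d built directly from the classes defined above.
model = ResNeXt(last_stride=1, with_ibn=False, block=Bottleneck,
                layers=[3, 4, 6, 3], baseWidth=4, cardinality=32)
model.eval()
with torch.no_grad():
    feat = model(torch.randn(1, 3, 256, 128))
print(feat.shape)  # expected: torch.Size([1, 2048, 16, 8]) with last_stride=1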