feat($modeling/backbones): add new backbones

add osnet, resnext and resnest backbone supported
pull/43/head
liaoxingyu 2020-04-24 12:14:56 +08:00
parent b098b194ba
commit e3ae03cc58
4 changed files with 1095 additions and 2 deletions

View File

@ -7,5 +7,6 @@
from .build import build_backbone, BACKBONE_REGISTRY
from .resnet import build_resnet_backbone
# from .osnet import *
# from .attention import ResidualAttentionNet_56
from .osnet import build_osnet_backbone
from .resnest import build_resnest_backbone
from .resnext import build_resnext_backbone

View File

@ -0,0 +1,493 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on:
# https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/models/osnet.py
import torch
from torch import nn
from torch.nn import functional as F
from .build import BACKBONE_REGISTRY
model_urls = {
'osnet_x1_0':
'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
'osnet_x0_75':
'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
'osnet_x0_5':
'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
'osnet_x0_25':
'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
'osnet_ibn_x1_0':
'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""Convolution layer (conv + bn + relu)."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
groups=1,
IN=False
):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
groups=groups
)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
else:
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
1,
stride=stride,
padding=0,
bias=False,
groups=groups
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, 1, stride=stride, padding=0, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
3,
stride=stride,
padding=1,
bias=False,
groups=groups
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, 1, stride=1, padding=0, bias=False
)
self.conv2 = nn.Conv2d(
out_channels,
out_channels,
3,
stride=1,
padding=1,
bias=False,
groups=out_channels
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
x = self.relu(x)
return x
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
def __init__(
self,
in_channels,
num_gates=None,
return_gates=False,
gate_activation='sigmoid',
reduction=16,
layer_norm=False
):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
in_channels,
in_channels // reduction,
kernel_size=1,
bias=True,
padding=0
)
self.norm1 = None
if layer_norm:
self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(
in_channels // reduction,
num_gates,
kernel_size=1,
bias=True,
padding=0
)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU(inplace=True)
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError(
"Unknown gate activation: {}".format(gate_activation)
)
def forward(self, x):
input = x
x = self.global_avgpool(x)
x = self.fc1(x)
if self.norm1 is not None:
x = self.norm1(x)
x = self.relu(x)
x = self.fc2(x)
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(
self,
in_channels,
out_channels,
IN=False,
bottleneck_reduction=4,
**kwargs
):
super(OSBlock, self).__init__()
mid_channels = out_channels // bottleneck_reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2a = LightConv3x3(mid_channels, mid_channels)
self.conv2b = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2c = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2d = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = None
if IN:
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2a = self.conv2a(x1)
x2b = self.conv2b(x1)
x2c = self.conv2c(x1)
x2d = self.conv2d(x1)
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
if self.IN is not None:
out = self.IN(out)
return F.relu(out)
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
- Zhou et al. Learning Generalisable Omni-Scale Representations
for Person Re-Identification. arXiv preprint, 2019.
"""
def __init__(
self,
blocks,
layers,
channels,
IN=False,
**kwargs
):
super(OSNet, self).__init__()
num_blocks = len(blocks)
assert num_blocks == len(layers)
assert num_blocks == len(channels) - 1
# convolutional backbone
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.conv2 = self._make_layer(
blocks[0],
layers[0],
channels[0],
channels[1],
reduce_spatial_size=True,
IN=IN
)
self.conv3 = self._make_layer(
blocks[1],
layers[1],
channels[1],
channels[2],
reduce_spatial_size=True
)
self.conv4 = self._make_layer(
blocks[2],
layers[2],
channels[2],
channels[3],
reduce_spatial_size=False
)
self.conv5 = Conv1x1(channels[3], channels[3])
self._init_params()
def _make_layer(
self,
block,
layer,
in_channels,
out_channels,
reduce_spatial_size,
IN=False
):
layers = []
layers.append(block(in_channels, out_channels, IN=IN))
for i in range(1, layer):
layers.append(block(out_channels, out_channels, IN=IN))
if reduce_spatial_size:
layers.append(
nn.Sequential(
Conv1x1(out_channels, out_channels),
nn.AvgPool2d(2, stride=2)
)
)
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu'
)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
return x
def init_pretrained_weights(model, key=''):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import os
import errno
import gdown
from collections import OrderedDict
import warnings
import logging
logger = logging.getLogger(__name__)
def _get_torch_home():
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
)
)
)
return torch_home
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')
try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise
filename = key + '_imagenet.pth'
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
gdown.download(model_urls[key], cached_file, quiet=False)
state_dict = torch.load(cached_file)
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
'The pretrained weights from "{}" cannot be loaded, '
'please check the key names manually '
'(** ignored and continue **)'.format(cached_file)
)
else:
logger.info(
'Successfully loaded imagenet pretrained weights from "{}"'.
format(cached_file)
)
if len(discarded_layers) > 0:
logger.info(
'** The following layers are discarded '
'due to unmatched keys or layer size: {}'.
format(discarded_layers)
)
@BACKBONE_REGISTRY.register()
def build_osnet_backbone(cfg):
"""
Create a OSNet instance from config.
Returns:
OSNet: a :class:`OSNet` instance
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
num_blocks_per_stage = [2, 2, 2]
num_channels_per_stage = [64, 256, 384, 512]
model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, with_ibn)
pretrain_key = 'osnet_ibn_x1_0' if with_ibn else 'osnet_x1_0'
if pretrain:
init_pretrained_weights(model, pretrain_key)
return model

View File

@ -0,0 +1,401 @@
# encoding: utf-8
# based on:
# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
"""ResNeSt models"""
import logging
import math
import torch
from torch import nn
from .build import BACKBONE_REGISTRY
from ...layers import SplAtConv2d, IBN, Non_local
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
_model_sha256 = {name: checksum for checksum, name in [
('528c19ca', 'resnest50'),
('22405ba7', 'resnest101'),
('75117900', 'resnest200'),
('0cc87c48', 'resnest269'),
]}
def short_hash(name):
if name not in _model_sha256:
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha256[name][:8]
model_urls = {name: _url_format.format(name, short_hash(name)) for
name in _model_sha256.keys()
}
class Bottleneck(nn.Module):
"""ResNet Bottleneck
"""
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None,
radix=1, cardinality=1, bottleneck_width=64,
avd=False, avd_first=False, dilation=1, is_first=False,
rectified_conv=False, rectify_avg=False,
norm_layer=None, dropblock_prob=0.0, last_gamma=False):
super(Bottleneck, self).__init__()
group_width = int(planes * (bottleneck_width / 64.)) * cardinality
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
if with_ibn:
self.bn1 = IBN(group_width)
else:
self.bn1 = norm_layer(group_width)
self.dropblock_prob = dropblock_prob
self.radix = radix
self.avd = avd and (stride > 1 or is_first)
self.avd_first = avd_first
if self.avd:
self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
stride = 1
if radix > 1:
self.conv2 = SplAtConv2d(
group_width, group_width, kernel_size=3,
stride=stride, padding=dilation,
dilation=dilation, groups=cardinality, bias=False,
radix=radix, rectify=rectified_conv,
rectify_avg=rectify_avg,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
elif rectified_conv:
from rfconv import RFConv2d
self.conv2 = RFConv2d(
group_width, group_width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation,
groups=cardinality, bias=False,
average_mode=rectify_avg)
self.bn2 = norm_layer(group_width)
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation,
groups=cardinality, bias=False)
self.bn2 = norm_layer(group_width)
self.conv3 = nn.Conv2d(
group_width, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes * 4)
if last_gamma:
from torch.nn.init import zeros_
zeros_(self.bn3.weight)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.dilation = dilation
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
if self.dropblock_prob > 0.0:
out = self.dropblock1(out)
out = self.relu(out)
if self.avd and self.avd_first:
out = self.avd_layer(out)
out = self.conv2(out)
if self.radix == 1:
out = self.bn2(out)
if self.dropblock_prob > 0.0:
out = self.dropblock2(out)
out = self.relu(out)
if self.avd and not self.avd_first:
out = self.avd_layer(out)
out = self.conv3(out)
out = self.bn3(out)
if self.dropblock_prob > 0.0:
out = self.dropblock3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNest(nn.Module):
"""ResNet Variants ResNest
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
layers : list of int
Numbers of layers in each block
classes : int, default 1000
Number of classification classes.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
for Synchronized Cross-GPU BachNormalization).
Reference:
- He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
"""
# pylint: disable=unused-variable
def __init__(self, last_stride, with_ibn, with_nl, block, layers, non_layers, radix=1, groups=1,
bottleneck_width=64,
dilated=False, dilation=1,
deep_stem=False, stem_width=64, avg_down=False,
rectified_conv=False, rectify_avg=False,
avd=False, avd_first=False,
final_drop=0.0, dropblock_prob=0,
last_gamma=False, norm_layer=nn.BatchNorm2d):
self.cardinality = groups
self.bottleneck_width = bottleneck_width
# ResNet-D params
self.inplanes = stem_width * 2 if deep_stem else 64
self.avg_down = avg_down
self.last_gamma = last_gamma
# ResNeSt params
self.radix = radix
self.avd = avd
self.avd_first = avd_first
super().__init__()
self.rectified_conv = rectified_conv
self.rectify_avg = rectify_avg
if rectified_conv:
from rfconv import RFConv2d
conv_layer = RFConv2d
else:
conv_layer = nn.Conv2d
conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
if deep_stem:
self.conv1 = nn.Sequential(
conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
norm_layer(stem_width),
nn.ReLU(inplace=True),
conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
norm_layer(stem_width),
nn.ReLU(inplace=True),
conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
)
else:
self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
bias=False, **conv_kwargs)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn, norm_layer=norm_layer, is_first=False)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn, norm_layer=norm_layer)
if dilated or dilation == 4:
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, with_ibn=with_ibn,
dilation=2, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, with_ibn=with_ibn,
dilation=4, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
elif dilation == 2:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn,
dilation=1, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, with_ibn=with_ibn,
dilation=2, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
else:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, norm_layer):
m.weight.data.fill_(1)
m.bias.data.zero_()
if with_nl:
self._build_nonlocal(layers, non_layers)
else:
self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, dilation=1, norm_layer=None,
dropblock_prob=0.0, is_first=True):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
down_layers = []
if self.avg_down:
if dilation == 1:
down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
ceil_mode=True, count_include_pad=False))
else:
down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
ceil_mode=True, count_include_pad=False))
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=1, bias=False))
else:
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False))
down_layers.append(norm_layer(planes * block.expansion))
downsample = nn.Sequential(*down_layers)
layers = []
if planes == 512:
with_ibn = False
if dilation == 1 or dilation == 2:
layers.append(block(self.inplanes, planes, with_ibn, stride, downsample=downsample,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
elif dilation == 4:
layers.append(block(self.inplanes, planes, with_ibn, stride, downsample=downsample,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, with_ibn,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=dilation, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
return nn.Sequential(*layers)
def _build_nonlocal(self, layers, non_layers):
self.NL_1 = nn.ModuleList(
[Non_local(256) for _ in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512) for _ in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024) for _ in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048) for _ in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
NL1_counter = 0
if len(self.NL_1_idx) == 0:
self.NL_1_idx = [-1]
for i in range(len(self.layer1)):
x = self.layer1[i](x)
if i == self.NL_1_idx[NL1_counter]:
_, C, H, W = x.shape
x = self.NL_1[NL1_counter](x)
NL1_counter += 1
# Layer 2
NL2_counter = 0
if len(self.NL_2_idx) == 0:
self.NL_2_idx = [-1]
for i in range(len(self.layer2)):
x = self.layer2[i](x)
if i == self.NL_2_idx[NL2_counter]:
_, C, H, W = x.shape
x = self.NL_2[NL2_counter](x)
NL2_counter += 1
# Layer 3
NL3_counter = 0
if len(self.NL_3_idx) == 0:
self.NL_3_idx = [-1]
for i in range(len(self.layer3)):
x = self.layer3[i](x)
if i == self.NL_3_idx[NL3_counter]:
_, C, H, W = x.shape
x = self.NL_3[NL3_counter](x)
NL3_counter += 1
# Layer 4
NL4_counter = 0
if len(self.NL_4_idx) == 0:
self.NL_4_idx = [-1]
for i in range(len(self.layer4)):
x = self.layer4[i](x)
if i == self.NL_4_idx[NL4_counter]:
_, C, H, W = x.shape
x = self.NL_4[NL4_counter](x)
NL4_counter += 1
return x
@BACKBONE_REGISTRY.register()
def build_resnest_backbone(cfg):
"""
Create a ResNest instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 200: [3, 24, 36, 3], 269: [3, 30, 48, 8]}[depth]
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
stem_width = {50: 32, 101: 64, 200: 64, 269: 64}[depth]
model = ResNest(last_stride, with_ibn, with_nl, Bottleneck, num_blocks_per_stage, nl_layers_per_stage,
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=stem_width, avg_down=True,
avd=True, avd_first=False)
if pretrain:
# if not with_ibn:
# original resnet
state_dict = torch.hub.load_state_dict_from_url(
model_urls['resnest' + str(depth)], progress=True, check_hash=True)
# else:
# raise KeyError('Not implementation ibn in resnest')
# # ibn resnet
# state_dict = torch.load(pretrain_path)['state_dict']
# # remove module in name
# new_state_dict = {}
# for k in state_dict:
# new_k = '.'.join(k.split('.')[1:])
# if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
# new_state_dict[new_k] = state_dict[k]
# state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
return model

View File

@ -0,0 +1,198 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on:
# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
import math
import logging
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch
from ...layers import IBN
from .build import BACKBONE_REGISTRY
class Bottleneck(nn.Module):
"""
RexNeXt bottleneck type C
"""
expansion = 4
def __init__(self, inplanes, planes, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
""" Constructor
Args:
inplanes: input channel dimensionality
planes: output channel dimensionality
baseWidth: base width.
cardinality: num of convolution groups.
stride: conv stride. Replaces pooling layer.
"""
super(Bottleneck, self).__init__()
D = int(math.floor(planes * (baseWidth / 64)))
C = cardinality
self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
if with_ibn:
self.bn1 = IBN(D * C)
else:
self.bn1 = nn.BatchNorm2d(D * C)
self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
self.bn2 = nn.BatchNorm2d(D * C)
self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNeXt(nn.Module):
"""
ResNext optimized for the ImageNet dataset, as specified in
https://arxiv.org/pdf/1611.05431.pdf
"""
def __init__(self, last_stride, with_ibn, block, layers, baseWidth=4, cardinality=32):
""" Constructor
Args:
baseWidth: baseWidth for ResNeXt.
cardinality: number of convolution groups.
layers: config of layers, e.g., [3, 4, 6, 3]
num_classes: number of classes
"""
super(ResNeXt, self).__init__()
self.cardinality = cardinality
self.baseWidth = baseWidth
self.inplanes = 64
self.output_size = 64
self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn)
self.random_init()
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
""" Stack n bottleneck modules where n is inferred from the depth of the network.
Args:
block: block type used to construct ResNext
planes: number of output channels (need to multiply by block.expansion)
blocks: number of blocks to be built
stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
Returns: a Module consisting of n sequential bottlenecks.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
if planes == 512:
with_ibn = False
layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, 1, None))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def random_init(self):
self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
"""
Create a ResNeXt instance from config.
Returns:
ResNeXt: a :class:`ResNeXt` instance.
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
model = ResNeXt(last_stride, with_ibn, Bottleneck, num_blocks_per_stage)
if pretrain:
# if not with_ibn:
# original resnet
# state_dict = model_zoo.load_url(model_urls[depth])
# else:
# ibn resnet
state_dict = torch.load(pretrain_path)['state_dict']
# remove module in name
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[1:])
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
return model