fast-reid/fastreid/modeling/backbones/resnest.py

339 lines
14 KiB
Python

# encoding: utf-8
# based on:
# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
"""ResNeSt models"""
import torch
from torch import nn
import math
import logging
from .resnet import ResNet, Bottleneck
from .build import BACKBONE_REGISTRY
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'

# Short (8 hex character) SHA-256 prefixes used to build the checkpoint URLs,
# keyed by model depth name.
_model_sha256 = {
    '50': '528c19ca',
    '101': '22405ba7',
    '200': '75117900',
    '269': '0cc87c48',
}
def short_hash(name):
    """Return the first 8 characters of the checksum registered for *name*.

    Raises:
        ValueError: if *name* has no entry in ``_model_sha256``.
    """
    if name in _model_sha256:
        return _model_sha256[name][:8]
    raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
# Download URL of the pretrained checkpoint for each depth key ('50', '101', ...).
# The upstream ResNeSt checkpoints are named 'resnest<depth>-<hash>.pth', so the
# 'resnest' prefix is added to the file name while the dict keys stay the bare
# depth strings. Formatting the bare key produced '<depth>-<hash>.pth', which is
# not the published file name.
resnest_model_urls = {
    name: _url_format.format('resnest' + name, short_hash(name))
    for name in _model_sha256
}
class Bottleneck(nn.Module):
    """ResNeSt bottleneck block: 1x1 conv -> 3x3 (split-attention) conv -> 1x1 conv.

    Parameters
    ----------
    inplanes : int
        Number of input channels.
    planes : int
        Base width; the block outputs ``planes * expansion`` channels.
    stride : int
        Spatial stride of the block. When average-pool downsampling (avd) is
        active, the stride is applied by ``avd_layer`` and conv2 uses stride 1.
    downsample : nn.Module or None
        Projection applied to the identity branch when shapes differ.
    radix : int
        Split-attention radix; ``radix > 1`` uses ``SplAtConv2d`` for conv2.
    cardinality : int
        Number of groups of the 3x3 convolution.
    bottleneck_width : int
        Per-group width, relative to 64.
    avd, avd_first : bool
        Enable the 3x3 average-pool downsampling and whether it runs before
        (``avd_first``) or after conv2. Only active when ``stride > 1`` or
        ``is_first``.
    dilation : int
        Dilation and padding of the 3x3 convolution.
    is_first : bool
        Allow avd even when ``stride == 1``.
    rectified_conv, rectify_avg : bool
        Use ``rfconv.RFConv2d`` (optional dependency) for conv2.
    norm_layer : callable or None
        Called as ``norm_layer(num_channels)``; defaults to ``nn.BatchNorm2d``.
    dropblock_prob : float
        DropBlock is not implemented for this block; any value > 0 raises
        ``NotImplementedError``.
    last_gamma : bool
        Zero-initialise the weight of the final norm layer (``bn3``).
    """
    # pylint: disable=unused-argument
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if dropblock_prob > 0.0:
            # forward() would need dropblock1/2/3 modules, which are never
            # created here; fail at construction instead of on the first
            # forward pass with an AttributeError.
            raise NotImplementedError('DropBlock is not implemented for the ResNeSt Bottleneck')

        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first
        if self.avd:
            # The spatial stride moves from conv2 into this pooling layer.
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if radix > 1:
            self.conv2 = SplAtConv2d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            from rfconv import RFConv2d
            self.conv2 = RFConv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv2d(group_width, out_planes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_planes)
        if last_gamma:
            nn.init.zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix <= 1:
            # bn2 exists exactly when conv2 is a plain/RFConv convolution
            # (radix <= 1); SplAtConv2d is handed norm_layer itself.
            out = self.relu(self.bn2(out))

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)
class ResNest(nn.Module):
    """ResNeSt backbone (ResNet variant built from split-attention blocks).

    This module has no classification head: ``forward`` returns the feature
    map produced by ``layer4``.

    Parameters
    ----------
    block : type
        Residual block class. It must expose an ``expansion`` attribute and
        accept the keyword arguments passed by ``_make_layer``.
    layers : list of int
        Number of blocks in each of the four stages.
    radix, groups, bottleneck_width : int
        Split-attention radix, cardinality and per-group width, forwarded to ``block``.
    num_classes : int
        Unused; kept for signature compatibility.
    dilated : bool
        If True, stages 3 and 4 use stride 1 with dilation 2 and 4 (output stride 8).
    dilation : int
        4 behaves like ``dilated=True``. 2 keeps stage 3 strided and dilates
        stage 4 (output stride 16). Any other value gives output stride 32.
    deep_stem : bool
        Replace the 7x7 stem conv with three 3x3 convs (output ``2 * stem_width`` channels).
    stem_width : int
        Width of the deep stem.
    avg_down : bool
        Shortcut projections use average pooling followed by a stride-1 1x1
        conv, instead of a strided 1x1 conv.
    rectified_conv, rectify_avg : bool
        Use ``rfconv.RFConv2d`` (optional dependency) in the stem and the blocks.
    avd, avd_first : bool
        Average-pool downsampling inside the blocks; forwarded to ``block``.
    final_drop : float
        Unused; kept for signature compatibility.
    dropblock_prob : float
        Forwarded to the blocks of stages 3 and 4.
    last_gamma : bool
        Zero-initialise each block's final norm layer (``bn3``) weight.
    norm_layer : type
        Normalization layer class, called as ``norm_layer(num_channels)``.

    Reference:
        - He, Kaiming, et al. "Deep residual learning for image recognition." CVPR 2016.
        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """
    # pylint: disable=unused-variable
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, norm_layer=nn.BatchNorm2d):
        # nn.Module must be initialised with this class; super(ResNet, self)
        # raised TypeError because ResNest does not derive from ResNet.
        super(ResNest, self).__init__()
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg

        if rectified_conv:
            from rfconv import RFConv2d
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}

        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)

        # (stride, dilation) for stage 3 and stage 4.
        if dilated or dilation == 4:
            (stride3, dilation3), (stride4, dilation4) = (1, 2), (1, 4)
        elif dilation == 2:
            (stride3, dilation3), (stride4, dilation4) = (2, 1), (1, 2)
        else:
            (stride3, dilation3), (stride4, dilation4) = (2, 1), (2, 1)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=stride3,
                                       dilation=dilation3, norm_layer=norm_layer,
                                       dropblock_prob=dropblock_prob)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=stride4,
                                       dilation=dilation4, norm_layer=norm_layer,
                                       dropblock_prob=dropblock_prob)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_layer):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if last_gamma:
            # The loop above resets every norm weight to 1, which undoes the
            # zero-init of bn3 that each block performed for last_gamma.
            for m in self.modules():
                bn3 = getattr(m, 'bn3', None) if isinstance(m, block) else None
                if bn3 is not None:
                    nn.init.zeros_(bn3.weight)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        """Build one stage of residual blocks.

        The first block applies ``stride`` and, when the width or resolution
        changes, a shortcut projection. It uses dilation 1 when ``dilation``
        is 1 or 2, and dilation 2 when ``dilation`` is 4. The remaining
        ``blocks - 1`` blocks use ``dilation`` unchanged and stride 1.

        Raises:
            RuntimeError: if ``dilation`` is not 1, 2 or 4.
        """
        first_dilation = {1: 1, 2: 1, 4: 2}.get(dilation)
        if first_dilation is None:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            down_layers = []
            if self.avg_down:
                # ResNet-D shortcut: pool first, then a stride-1 projection.
                pool = stride if dilation == 1 else 1
                down_layers.append(nn.AvgPool2d(kernel_size=pool, stride=pool,
                                                ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, out_planes,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, out_planes,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(out_planes))
            downsample = nn.Sequential(*down_layers)

        # Keyword arguments shared by every block of the stage.
        block_kwargs = dict(
            radix=self.radix, cardinality=self.cardinality,
            bottleneck_width=self.bottleneck_width,
            avd=self.avd, avd_first=self.avd_first,
            rectified_conv=self.rectified_conv,
            rectify_avg=self.rectify_avg,
            norm_layer=norm_layer, dropblock_prob=dropblock_prob,
            last_gamma=self.last_gamma,
        )

        layers = [block(self.inplanes, planes, stride, downsample=downsample,
                        dilation=first_dilation, is_first=is_first, **block_kwargs)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, **block_kwargs))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
@BACKBONE_REGISTRY.register()
def build_resnest_backbone(cfg):
    """
    Create a ResNeSt backbone from config.

    Reads ``cfg.MODEL.BACKBONE.PRETRAIN``, ``WITH_IBN`` and ``DEPTH``. DEPTH
    (50, 101, 200 or 269) selects both the number of blocks per stage and the
    deep-stem width.

    Returns:
        ResNest: a :class:`ResNest` instance. When PRETRAIN is set, it is loaded
        (non-strictly) from ``resnest_model_urls``.

    Raises:
        KeyError: if DEPTH is not a supported depth, or if pretrained weights
            are requested together with IBN (not implemented for ResNeSt).
    """
    # fmt: off
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
    depth    = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on
    # NOTE(review): LAST_STRIDE, WITH_SE and PRETRAIN_PATH are not applied by
    # this builder; ResNest has no option for them.

    # depth -> (blocks per stage, deep-stem width). Depths above 50 use a
    # 64-wide stem, so a fixed width of 32 would not match their checkpoints.
    num_blocks_per_stage, stem_width = {
        50: ([3, 4, 6, 3], 32),
        101: ([3, 4, 23, 3], 64),
        200: ([3, 24, 36, 3], 64),
        269: ([3, 30, 48, 8], 64),
    }[depth]

    model = ResNest(Bottleneck, num_blocks_per_stage,
                    radix=2, groups=1, bottleneck_width=64,
                    deep_stem=True, stem_width=stem_width, avg_down=True,
                    avd=True, avd_first=False)

    if pretrain:
        if with_ibn:
            raise KeyError('Not implementation ibn in resnest')
        # resnest_model_urls is keyed by the string form of the depth ('50', ...).
        state_dict = torch.hub.load_state_dict_from_url(
            resnest_model_urls[str(depth)], progress=True, check_hash=True)
        res = model.load_state_dict(state_dict, strict=False)
        logger = logging.getLogger(__name__)
        logger.info('missing keys is %s', res.missing_keys)
        logger.info('unexpected keys is %s', res.unexpected_keys)
    return model