mirror of https://github.com/JDAI-CV/fast-reid.git
204 lines
6.9 KiB
Python
204 lines
6.9 KiB
Python
"""
|
|
Author: Guan'an Wang
|
|
Contact: guan.wang0706@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
from collections import OrderedDict
|
|
import logging
|
|
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
|
|
|
from fastreid.layers import get_norm
|
|
from fastreid.modeling.backbones import BACKBONE_REGISTRY
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ShuffleV2Block(nn.Module):
|
|
"""
|
|
Reference:
|
|
https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2
|
|
"""
|
|
|
|
def __init__(self, bn_norm, inp, oup, mid_channels, *, ksize, stride):
|
|
super(ShuffleV2Block, self).__init__()
|
|
self.stride = stride
|
|
assert stride in [1, 2]
|
|
|
|
self.mid_channels = mid_channels
|
|
self.ksize = ksize
|
|
pad = ksize // 2
|
|
self.pad = pad
|
|
self.inp = inp
|
|
|
|
outputs = oup - inp
|
|
|
|
branch_main = [
|
|
# pw
|
|
nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
|
|
get_norm(bn_norm, mid_channels),
|
|
nn.ReLU(inplace=True),
|
|
# dw
|
|
nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
|
|
get_norm(bn_norm, mid_channels),
|
|
# pw-linear
|
|
nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
|
|
get_norm(bn_norm, outputs),
|
|
nn.ReLU(inplace=True),
|
|
]
|
|
self.branch_main = nn.Sequential(*branch_main)
|
|
|
|
if stride == 2:
|
|
branch_proj = [
|
|
# dw
|
|
nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
|
|
get_norm(bn_norm, inp),
|
|
# pw-linear
|
|
nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
|
|
get_norm(bn_norm, inp),
|
|
nn.ReLU(inplace=True),
|
|
]
|
|
self.branch_proj = nn.Sequential(*branch_proj)
|
|
else:
|
|
self.branch_proj = None
|
|
|
|
def forward(self, old_x):
|
|
if self.stride == 1:
|
|
x_proj, x = self.channel_shuffle(old_x)
|
|
return torch.cat((x_proj, self.branch_main(x)), 1)
|
|
elif self.stride == 2:
|
|
x_proj = old_x
|
|
x = old_x
|
|
return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)
|
|
|
|
def channel_shuffle(self, x):
|
|
batchsize, num_channels, height, width = x.data.size()
|
|
assert (num_channels % 4 == 0)
|
|
x = x.reshape(batchsize * num_channels // 2, 2, height * width)
|
|
x = x.permute(1, 0, 2)
|
|
x = x.reshape(2, -1, num_channels // 2, height, width)
|
|
return x[0], x[1]
|
|
|
|
|
|
class ShuffleNetV2(nn.Module):
|
|
"""
|
|
Reference:
|
|
https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2
|
|
"""
|
|
|
|
def __init__(self, bn_norm, model_size='1.5x'):
|
|
super(ShuffleNetV2, self).__init__()
|
|
|
|
self.stage_repeats = [4, 8, 4]
|
|
self.model_size = model_size
|
|
if model_size == '0.5x':
|
|
self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
|
|
elif model_size == '1.0x':
|
|
self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
|
|
elif model_size == '1.5x':
|
|
self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
|
|
elif model_size == '2.0x':
|
|
self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
# building first layer
|
|
input_channel = self.stage_out_channels[1]
|
|
self.first_conv = nn.Sequential(
|
|
nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
|
|
get_norm(bn_norm, input_channel),
|
|
nn.ReLU(inplace=True),
|
|
)
|
|
|
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
|
|
self.features = []
|
|
for idxstage in range(len(self.stage_repeats)):
|
|
numrepeat = self.stage_repeats[idxstage]
|
|
output_channel = self.stage_out_channels[idxstage + 2]
|
|
|
|
for i in range(numrepeat):
|
|
if i == 0:
|
|
self.features.append(ShuffleV2Block(bn_norm, input_channel, output_channel,
|
|
mid_channels=output_channel // 2, ksize=3, stride=2))
|
|
else:
|
|
self.features.append(ShuffleV2Block(bn_norm, input_channel // 2, output_channel,
|
|
mid_channels=output_channel // 2, ksize=3, stride=1))
|
|
|
|
input_channel = output_channel
|
|
|
|
self.features = nn.Sequential(*self.features)
|
|
|
|
self.conv_last = nn.Sequential(
|
|
nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=False),
|
|
get_norm(bn_norm, self.stage_out_channels[-1]),
|
|
nn.ReLU(inplace=True)
|
|
)
|
|
|
|
self._initialize_weights()
|
|
|
|
def forward(self, x):
|
|
x = self.first_conv(x)
|
|
x = self.maxpool(x)
|
|
x = self.features(x)
|
|
x = self.conv_last(x)
|
|
|
|
return x
|
|
|
|
def _initialize_weights(self):
|
|
for name, m in self.named_modules():
|
|
if isinstance(m, nn.Conv2d):
|
|
if 'first' in name:
|
|
nn.init.normal_(m.weight, 0, 0.01)
|
|
else:
|
|
nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
nn.init.constant_(m.weight, 1)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0.0001)
|
|
nn.init.constant_(m.running_mean, 0)
|
|
elif isinstance(m, nn.BatchNorm1d):
|
|
nn.init.constant_(m.weight, 1)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0.0001)
|
|
nn.init.constant_(m.running_mean, 0)
|
|
elif isinstance(m, nn.Linear):
|
|
nn.init.normal_(m.weight, 0, 0.01)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
@BACKBONE_REGISTRY.register()
|
|
def build_shufflenetv2_backbone(cfg):
|
|
# fmt: off
|
|
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
|
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
|
|
bn_norm = cfg.MODEL.BACKBONE.NORM
|
|
model_size = cfg.MODEL.BACKBONE.DEPTH
|
|
# fmt: on
|
|
|
|
model = ShuffleNetV2(bn_norm, model_size=model_size)
|
|
|
|
if pretrain:
|
|
new_state_dict = OrderedDict()
|
|
state_dict = torch.load(pretrain_path)["state_dict"]
|
|
for k, v in state_dict.items():
|
|
if k[:7] == 'module.':
|
|
k = k[7:]
|
|
new_state_dict[k] = v
|
|
|
|
incompatible = model.load_state_dict(new_state_dict, strict=False)
|
|
if incompatible.missing_keys:
|
|
logger.info(
|
|
get_missing_parameters_message(incompatible.missing_keys)
|
|
)
|
|
if incompatible.unexpected_keys:
|
|
logger.info(
|
|
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
|
)
|
|
|
|
return model
|