172 lines
7.0 KiB
Python
172 lines
7.0 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from .backbones.resnet import ResNet, BasicBlock, Bottleneck
|
|
from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck
|
|
|
|
|
|
def weights_init_kaiming(m):
    """Kaiming-style initialiser, written to be passed to ``module.apply``.

    The layer kind is detected from the class *name* (substring match), so
    any class whose name contains ``Linear``, ``Conv`` or ``BatchNorm`` is
    handled; every other module is left untouched.

    * ``Linear``: weight ~ Kaiming normal (``fan_out``), bias set to 0.
    * ``Conv``: weight ~ Kaiming normal (``fan_in``), bias set to 0.
    * ``BatchNorm``: when affine, weight set to 1 and bias set to 0.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__

    if 'Linear' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # Layers built with bias=False expose ``bias`` as None; skip them
        # instead of crashing inside nn.init.constant_.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in classname:
        # Non-affine batch norm has no learnable weight/bias to reset.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
|
|
|
|
|
|
def weights_init_classifier(m):
    """Initialiser for classifier heads, written to be passed to ``module.apply``.

    Modules whose class name contains ``Linear`` get their weight drawn from
    a normal distribution with ``std=0.001`` and, when a bias exists, the
    bias set to 0. All other modules are left untouched.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    if 'Linear' not in classname:
        return

    nn.init.normal_(m.weight, std=0.001)
    # Compare against None explicitly: taking the truth value of a bias
    # tensor with more than one element raises RuntimeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)
|
|
|
|
|
|
class Baseline(nn.Module):
    """Backbone + global average pooling + optional BN neck + linear classifier.

    Args:
        num_classes: output size of the classifier.
        last_stride: forwarded to the backbone constructor as ``last_stride``.
        model_path: passed to ``self.base.load_param`` right after the
            backbone is built.
        neck: ``'no'`` feeds the pooled feature straight to the classifier;
            ``'bnneck'`` inserts a ``BatchNorm1d`` (bias frozen) in between
            and uses a bias-free classifier.
        neck_feat: in eval mode, ``'after'`` returns the feature after the
            neck; any other value returns the pooled feature.
        model_name: one of ``resnet18/34/50/101/152``,
            ``se_resnet50/101/152``, ``se_resnext50/101``, ``senet154``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """

    # Channel width of the pooled backbone feature. Overridden per instance
    # for the BasicBlock backbones (see __init__).
    in_planes = 2048

    def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name):
        super(Baseline, self).__init__()

        if model_name == 'resnet18':
            # NOTE(review): BasicBlock is assumed to have expansion 1, so the
            # last stage emits 512 channels rather than 2048 - confirm
            # against .backbones.resnet.
            self.in_planes = 512
            self.base = ResNet(last_stride=last_stride,
                               block=BasicBlock,
                               layers=[2, 2, 2, 2])
        elif model_name == 'resnet34':
            self.in_planes = 512  # same BasicBlock assumption as resnet18
            self.base = ResNet(last_stride=last_stride,
                               block=BasicBlock,
                               layers=[3, 4, 6, 3])
        elif model_name == 'resnet50':
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               layers=[3, 4, 6, 3])
        elif model_name == 'resnet101':
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               layers=[3, 4, 23, 3])
        elif model_name == 'resnet152':
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               layers=[3, 8, 36, 3])
        elif model_name == 'se_resnet50':
            self.base = self._se_backbone(SEResNetBottleneck, [3, 4, 6, 3], 1, last_stride)
        elif model_name == 'se_resnet101':
            self.base = self._se_backbone(SEResNetBottleneck, [3, 4, 23, 3], 1, last_stride)
        elif model_name == 'se_resnet152':
            # 152-layer layout is [3, 8, 36, 3], matching resnet152 and
            # senet154 in this class (was [3, 4, 36, 3]).
            self.base = self._se_backbone(SEResNetBottleneck, [3, 8, 36, 3], 1, last_stride)
        elif model_name == 'se_resnext50':
            self.base = self._se_backbone(SEResNeXtBottleneck, [3, 4, 6, 3], 32, last_stride)
        elif model_name == 'se_resnext101':
            self.base = self._se_backbone(SEResNeXtBottleneck, [3, 4, 23, 3], 32, last_stride)
        elif model_name == 'senet154':
            self.base = SENet(block=SEBottleneck,
                              layers=[3, 8, 36, 3],
                              groups=64,
                              reduction=16,
                              dropout_p=0.2,
                              last_stride=last_stride)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the missing ``self.base`` a few lines further down.
            raise ValueError('Unsupported model_name: {!r}'.format(model_name))

        self.base.load_param(model_path)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.num_classes = num_classes
        self.neck = neck
        self.neck_feat = neck_feat

        if self.neck == 'no':
            self.classifier = nn.Linear(self.in_planes, self.num_classes)
        elif self.neck == 'bnneck':
            self.bottleneck = nn.BatchNorm1d(self.in_planes)
            self.bottleneck.bias.requires_grad_(False)  # no shift
            self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)

            self.bottleneck.apply(weights_init_kaiming)
            self.classifier.apply(weights_init_classifier)

    @staticmethod
    def _se_backbone(block, layers, groups, last_stride):
        """Build an SENet with the settings shared by the SE-ResNet/SE-ResNeXt variants."""
        return SENet(block=block,
                     layers=layers,
                     groups=groups,
                     reduction=16,
                     dropout_p=None,
                     inplanes=64,
                     input_3x3=False,
                     downsample_kernel_size=1,
                     downsample_padding=0,
                     last_stride=last_stride)

    def forward(self, x):
        """Run the backbone and return training or inference outputs.

        Returns:
            In training mode, ``(cls_score, global_feat)``: the classifier
            output and the pooled feature (the latter for the triplet loss).
            In eval mode, the feature after the neck when
            ``neck_feat == 'after'``, otherwise the pooled feature.
        """
        global_feat = self.gap(self.base(x))  # pooled to 1x1 spatially
        global_feat = global_feat.view(global_feat.shape[0], -1)  # flatten to (batch, channels)

        if self.neck == 'no':
            feat = global_feat
        elif self.neck == 'bnneck':
            feat = self.bottleneck(global_feat)  # normalize for angular softmax

        if self.training:
            cls_score = self.classifier(feat)
            return cls_score, global_feat  # global feature for triplet loss

        if self.neck_feat == 'after':
            return feat
        return global_feat

    def load_param(self, trained_path):
        """Copy every non-classifier entry of a saved state dict into this model.

        Args:
            trained_path: checkpoint location handed to ``torch.load``; the
                loaded object is iterated as a mapping of parameter name to
                tensor. Entries whose name contains ``'classifier'`` are
                skipped.

        Raises:
            KeyError: if the checkpoint holds a non-classifier name that this
                model does not have.
        """
        # NOTE(review): torch.load unpickles its input - only load trusted
        # checkpoints. map_location='cpu' lets GPU-saved checkpoints load on
        # CPU-only hosts; copy_ moves the values to each parameter's device.
        param_dict = torch.load(trained_path, map_location='cpu')
        # Build the state dict once; its tensors share storage with the
        # model's parameters, so copy_ updates the model in place.
        own_state = self.state_dict()
        for name, value in param_dict.items():
            if 'classifier' in name:
                continue
            own_state[name].copy_(value)