fast-reid/models/baseline_model.py

94 lines
2.8 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import itertools
import torch.nn.functional as F
from torch import nn
from .resnet import ResNet
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
nn.init.constant_(m.bias, 0.0)
elif classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif classname.find('BatchNorm') != -1:
if m.affine:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.normal_(m.weight, std=0.001)
nn.init.constant_(m.bias, 0.0)
class ResNetBuilder(nn.Module):
in_planes = 2048
def __init__(self, num_classes=None, last_stride=1, model_path='/DATA/model_zoo/resnet50-19c8e357.pth'):
super().__init__()
self.base = ResNet(last_stride)
self.base.load_param(model_path)
self.num_classes = num_classes
if num_classes is not None:
self.bottleneck = nn.Sequential(
nn.Linear(self.in_planes, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.1),
nn.Dropout(p=0.5)
)
self.bottleneck.apply(weights_init_kaiming)
self.classifier = nn.Linear(512, self.num_classes)
self.classifier.apply(weights_init_classifier)
def forward(self, x):
global_feat = self.base(x)
global_feat = F.avg_pool2d(global_feat, global_feat.shape[2:]) # (b, 2048, 1, 1)
global_feat = global_feat.view(global_feat.shape[0], -1)
if self.training and self.num_classes is not None:
feat = self.bottleneck(global_feat)
cls_score = self.classifier(feat)
return cls_score, global_feat
else:
return global_feat
def get_optim_policy(self):
base_param_group = self.base.parameters()
if self.num_classes is not None:
add_param_group = itertools.chain(self.bottleneck.parameters(), self.classifier.parameters())
return [
{'params': base_param_group},
{'params': add_param_group}
]
else:
return [
{'params': base_param_group}
]
if __name__ == '__main__':
net = ResNetBuilder(None)
net.cuda()
import torch as th
x = th.ones(2, 3, 256, 128).cuda()
y = net(x)
from IPython import embed
embed()