From c05f96e31bc7033130a663d8055fdfe97ffd04db Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Mon, 26 Nov 2018 22:07:20 +0000 Subject: [PATCH] add link to shufflenet's imagenet weights --- torchreid/models/__init__.py | 2 +- torchreid/models/shufflenet.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/torchreid/models/__init__.py b/torchreid/models/__init__.py index 751567c..9579ffb 100644 --- a/torchreid/models/__init__.py +++ b/torchreid/models/__init__.py @@ -39,7 +39,7 @@ __model_factory = { # lightweight models 'nasnsetmobile': nasnetamobile, 'mobilenetv2': MobileNetV2, - 'shufflenet': ShuffleNet, + 'shufflenet': shufflenet, 'squeezenet1_0': squeezenet1_0, 'squeezenet1_0_fc512': squeezenet1_0_fc512, 'squeezenet1_1': squeezenet1_1, diff --git a/torchreid/models/shufflenet.py b/torchreid/models/shufflenet.py index 3d73d38..a989726 100644 --- a/torchreid/models/shufflenet.py +++ b/torchreid/models/shufflenet.py @@ -5,9 +5,16 @@ import torch from torch import nn from torch.nn import functional as F import torchvision +import torch.utils.model_zoo as model_zoo -__all__ = ['ShuffleNet'] +__all__ = ['shufflenet'] + + +model_urls = { + # training epoch = 90, top1 = 61.8 + 'imagenet': 'http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/imagenet-pretrained/shufflenet-bee1b265.pth.tar', +} class ChannelShuffle(nn.Module): @@ -132,4 +139,24 @@ class ShuffleNet(nn.Module): elif self.loss == {'xent', 'htri'}: return y, x else: - raise KeyError("Unsupported loss: {}".format(self.loss)) \ No newline at end of file + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """ + Initialize model with pretrained weights. + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + print("Initialized model with pretrained weights from {}".format(model_url)) + + +def shufflenet(num_classes, loss, pretrained='imagenet', **kwargs): + model = ShuffleNet(num_classes, loss, **kwargs) + if pretrained == 'imagenet': + init_pretrained_weights(model, model_urls['imagenet']) + return model \ No newline at end of file