add link to shufflenet's imagenet weights

pull/119/head
KaiyangZhou 2018-11-26 22:07:20 +00:00
parent 47b883d0d6
commit c05f96e31b
2 changed files with 30 additions and 3 deletions

View File

@ -39,7 +39,7 @@ __model_factory = {
# lightweight models
'nasnsetmobile': nasnetamobile,
'mobilenetv2': MobileNetV2,
'shufflenet': ShuffleNet,
'shufflenet': shufflenet,
'squeezenet1_0': squeezenet1_0,
'squeezenet1_0_fc512': squeezenet1_0_fc512,
'squeezenet1_1': squeezenet1_1,

View File

@ -5,9 +5,16 @@ import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torch.utils.model_zoo as model_zoo
__all__ = ['ShuffleNet']
__all__ = ['shufflenet']
model_urls = {
# training epoch = 90, top1 = 61.8
'imagenet': 'http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/imagenet-pretrained/shufflenet-bee1b265.pth.tar',
}
class ChannelShuffle(nn.Module):
@ -132,4 +139,24 @@ class ShuffleNet(nn.Module):
elif self.loss == {'xent', 'htri'}:
return y, x
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""
Initialize model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Initialized model with pretrained weights from {}".format(model_url))
def shufflenet(num_classes, loss, pretrained='imagenet', **kwargs):
model = ShuffleNet(num_classes, loss, **kwargs)
if pretrained == 'imagenet':
init_pretrained_weights(model, model_urls['imagenet'])
return model