add link to shufflenet's imagenet weights
parent
47b883d0d6
commit
c05f96e31b
|
@ -39,7 +39,7 @@ __model_factory = {
|
|||
# lightweight models
|
||||
'nasnsetmobile': nasnetamobile,
|
||||
'mobilenetv2': MobileNetV2,
|
||||
'shufflenet': ShuffleNet,
|
||||
'shufflenet': shufflenet,
|
||||
'squeezenet1_0': squeezenet1_0,
|
||||
'squeezenet1_0_fc512': squeezenet1_0_fc512,
|
||||
'squeezenet1_1': squeezenet1_1,
|
||||
|
|
|
@ -5,9 +5,16 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
|
||||
__all__ = ['ShuffleNet']
|
||||
__all__ = ['shufflenet']
|
||||
|
||||
|
||||
model_urls = {
|
||||
# training epoch = 90, top1 = 61.8
|
||||
'imagenet': 'http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/imagenet-pretrained/shufflenet-bee1b265.pth.tar',
|
||||
}
|
||||
|
||||
|
||||
class ChannelShuffle(nn.Module):
|
||||
|
@ -132,4 +139,24 @@ class ShuffleNet(nn.Module):
|
|||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, x
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
def init_pretrained_weights(model, model_url):
|
||||
"""
|
||||
Initialize model with pretrained weights.
|
||||
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||
"""
|
||||
pretrain_dict = model_zoo.load_url(model_url)
|
||||
model_dict = model.state_dict()
|
||||
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
|
||||
model_dict.update(pretrain_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
print("Initialized model with pretrained weights from {}".format(model_url))
|
||||
|
||||
|
||||
def shufflenet(num_classes, loss, pretrained='imagenet', **kwargs):
|
||||
model = ShuffleNet(num_classes, loss, **kwargs)
|
||||
if pretrained == 'imagenet':
|
||||
init_pretrained_weights(model, model_urls['imagenet'])
|
||||
return model
|
Loading…
Reference in New Issue