deep-person-reid/torchreid/models/__init__.py

61 lines
1.6 KiB
Python
Raw Normal View History

2018-07-02 17:33:10 +08:00
from __future__ import absolute_import
from .resnet import *
2018-10-27 06:05:40 +08:00
from .resnetmid import *
2018-10-27 23:52:05 +08:00
from .resnext import *
2018-10-27 06:54:13 +08:00
from .senet import *
2018-10-27 18:57:48 +08:00
from .densenet import *
2018-10-28 00:46:48 +08:00
from .inceptionresnetv2 import *
from .inceptionv4 import *
from .xception import *
2018-10-28 00:59:10 +08:00
2018-10-28 00:46:48 +08:00
from .nasnet import *
2018-10-27 06:54:13 +08:00
from .mobilenetv2 import *
from .shufflenet import *
2018-10-27 21:42:28 +08:00
from .squeezenet import *
2018-07-02 17:33:10 +08:00
from .mudeep import *
from .hacnn import *
2018-10-28 19:16:42 +08:00
from .pcb import *
2018-07-02 17:33:10 +08:00
# Registry mapping a model name (the string passed to init_model) to the
# callable that builds that model. The callables come from the star imports
# at the top of this module; init_model calls them with the caller's
# *args/**kwargs.
__model_factory = {
    # image classification models
    'resnet50': resnet50,
    'resnet50_fc512': resnet50_fc512,
    'resnet50mid': resnet50mid,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x4d': resnext101_32x4d,
    'se_resnet50': se_resnet50,
    'se_resnet50_fc512': se_resnet50_fc512,
    'se_resnet101': se_resnet101,
    'se_resnext50_32x4d': se_resnext50_32x4d,
    'se_resnext101_32x4d': se_resnext101_32x4d,
    'densenet121': densenet121,
    'densenet121_fc512': densenet121_fc512,
    'inceptionresnetv2': InceptionResNetV2,
    'inceptionv4': inceptionv4,
    'xception': xception,
    # lightweight models
    'nasnetamobile': nasnetamobile,
    # Misspelled legacy key, kept as an alias so existing scripts/configs
    # that request 'nasnsetmobile' keep working.
    'nasnsetmobile': nasnetamobile,
    'mobilenetv2': MobileNetV2,
    'shufflenet': ShuffleNet,
    'squeezenet1_0': squeezenet1_0,
    'squeezenet1_0_fc512': squeezenet1_0_fc512,
    'squeezenet1_1': squeezenet1_1,
    # reid-specific models
    'mudeep': MuDeep,
    'hacnn': HACNN,
    'pcb_p6': pcb_p6,
    'pcb_p4': pcb_p4,
}
def get_names():
    """Return the registered model names as a new list.

    The list is a fresh copy of the registry's keys, in registry order;
    mutating it does not affect the registry.
    """
    # Iterating a dict yields its keys, so list(d) == list(d.keys()).
    return list(__model_factory)
2018-07-02 17:33:10 +08:00
def init_model(name, *args, **kwargs):
    """Build a model by its registered name.

    Args:
        name: A key of the model registry (see ``get_names()``).
        *args, **kwargs: Forwarded unchanged to the registered constructor.

    Returns:
        Whatever the registered constructor returns for these arguments.

    Raises:
        KeyError: If ``name`` is not a registered model name. The message
            lists the valid names to make typos easy to spot.
    """
    # Direct dict membership: O(1), and no throwaway list of keys.
    if name not in __model_factory:
        raise KeyError(
            "Unknown model: {}. Must be one of {}".format(
                name, list(__model_factory)
            )
        )
    # Kept outside any try/except so a KeyError raised *inside* a model
    # constructor is not mistaken for an unknown model name.
    return __model_factory[name](*args, **kwargs)