120 lines
3.4 KiB
Python
Raw Normal View History

2018-07-02 10:33:10 +01:00
from __future__ import absolute_import
2019-03-19 17:26:08 +00:00
import torch
2019-12-01 02:35:44 +00:00
from .pcb import *
from .mlfn import *
from .hacnn import *
from .osnet import *
2018-10-26 23:54:13 +01:00
from .senet import *
2019-12-01 02:35:44 +00:00
from .mudeep import *
from .nasnet import *
from .resnet import *
2018-10-27 11:57:48 +01:00
from .densenet import *
2018-10-27 17:46:48 +01:00
from .xception import *
2019-12-01 02:35:44 +00:00
from .osnet_ain import *
from .resnetmid import *
2018-10-26 23:54:13 +01:00
from .shufflenet import *
2018-10-27 14:42:28 +01:00
from .squeezenet import *
2019-12-01 02:35:44 +00:00
from .inceptionv4 import *
from .mobilenetv2 import *
from .resnet_ibn_a import *
from .resnet_ibn_b import *
from .shufflenetv2 import *
2019-12-01 02:35:44 +00:00
from .inceptionresnetv2 import *
2018-07-02 10:33:10 +01:00
# Registry mapping the public model name (as passed to ``build_model``) to the
# constructor imported from the sibling model modules.
__model_factory = {
    # image classification models
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext101_32x8d': resnext101_32x8d,
    'resnet50_fc512': resnet50_fc512,
    'se_resnet50': se_resnet50,
    'se_resnet50_fc512': se_resnet50_fc512,
    'se_resnet101': se_resnet101,
    'se_resnext50_32x4d': se_resnext50_32x4d,
    'se_resnext101_32x4d': se_resnext101_32x4d,
    'densenet121': densenet121,
    'densenet169': densenet169,
    'densenet201': densenet201,
    'densenet161': densenet161,
    'densenet121_fc512': densenet121_fc512,
    'inceptionresnetv2': inceptionresnetv2,
    'inceptionv4': inceptionv4,
    'xception': xception,
    'resnet50_ibn_a': resnet50_ibn_a,
    'resnet50_ibn_b': resnet50_ibn_b,
    # lightweight models
    # NOTE: 'nasnsetmobile' is a historical misspelling; it is kept so that
    # existing configs keep working, and the correctly spelled name is
    # registered as an alias for the same constructor.
    'nasnsetmobile': nasnetamobile,
    'nasnetamobile': nasnetamobile,
    'mobilenetv2_x1_0': mobilenetv2_x1_0,
    'mobilenetv2_x1_4': mobilenetv2_x1_4,
    'shufflenet': shufflenet,
    'squeezenet1_0': squeezenet1_0,
    'squeezenet1_0_fc512': squeezenet1_0_fc512,
    'squeezenet1_1': squeezenet1_1,
    'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
    # reid-specific models
    'mudeep': MuDeep,
    'resnet50mid': resnet50mid,
    'hacnn': HACNN,
    'pcb_p6': pcb_p6,
    'pcb_p4': pcb_p4,
    'mlfn': mlfn,
    'osnet_x1_0': osnet_x1_0,
    'osnet_x0_75': osnet_x0_75,
    'osnet_x0_5': osnet_x0_5,
    'osnet_x0_25': osnet_x0_25,
    'osnet_ibn_x1_0': osnet_ibn_x1_0,
    'osnet_ain_x1_0': osnet_ain_x1_0,
}
def show_avai_models():
    """Prints the names of all models registered in the model factory.

    The names are printed as a single Python list, in registration order.

    Examples::
        >>> from torchreid import models
        >>> models.show_avai_models()
    """
    model_names = [model_name for model_name in __model_factory]
    print(model_names)
2019-12-01 02:35:44 +00:00
def build_model(
    name, num_classes, loss='softmax', pretrained=True, use_gpu=True
):
    """Builds a model by looking up its constructor in the model factory.

    All arguments other than ``name`` are forwarded to the selected
    constructor as keyword arguments.

    Args:
        name (str): model name; must be a key of the model factory.
        num_classes (int): number of training identities.
        loss (str, optional): loss function to optimize the model. Currently
            supports "softmax" and "triplet". Default is "softmax".
        pretrained (bool, optional): whether to load ImageNet-pretrained
            weights. Default is True.
        use_gpu (bool, optional): whether to use gpu. Default is True.

    Returns:
        nn.Module

    Raises:
        KeyError: if ``name`` is not a registered model name.

    Examples::
        >>> from torchreid import models
        >>> model = models.build_model('resnet50', 751, loss='softmax')
    """
    # A list (rather than the dict itself) is used for the membership test so
    # the same object can be shown verbatim in the error message.
    registered_names = list(__model_factory)

    if name in registered_names:
        constructor = __model_factory[name]
        return constructor(
            num_classes=num_classes,
            loss=loss,
            pretrained=pretrained,
            use_gpu=use_gpu,
        )

    raise KeyError(
        'Unknown model: {}. Must be one of {}'.format(name, registered_names)
    )