deep-person-reid/torchreid/models/__init__.py

104 lines
2.9 KiB
Python
Raw Normal View History

2018-07-02 17:33:10 +08:00
from __future__ import absolute_import
2019-03-20 01:26:08 +08:00
import torch
2018-07-02 17:33:10 +08:00
from .resnet import *
2018-10-27 06:05:40 +08:00
from .resnetmid import *
2018-10-27 23:52:05 +08:00
from .resnext import *
2018-10-27 06:54:13 +08:00
from .senet import *
2018-10-27 18:57:48 +08:00
from .densenet import *
2018-10-28 00:46:48 +08:00
from .inceptionresnetv2 import *
from .inceptionv4 import *
from .xception import *
2018-10-28 00:59:10 +08:00
2018-10-28 00:46:48 +08:00
from .nasnet import *
2018-10-27 06:54:13 +08:00
from .mobilenetv2 import *
from .shufflenet import *
2018-10-27 21:42:28 +08:00
from .squeezenet import *
2018-07-02 17:33:10 +08:00
from .mudeep import *
from .hacnn import *
2018-10-28 19:16:42 +08:00
from .pcb import *
2018-11-04 08:18:02 +08:00
from .mlfn import *
2018-07-02 17:33:10 +08:00
# Registry mapping a model name (the ``name`` argument of ``build_model``) to
# the constructor that builds it. ``build_model`` calls every constructor with
# the keyword arguments num_classes, loss, pretrained and use_gpu.
__model_factory = {
    # image classification models
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
    'resnet101': resnet101,
    'resnet152': resnet152,
    'resnet50_fc512': resnet50_fc512,
    'resnext50_32x4d': resnext50_32x4d,
    'resnext50_32x4d_fc512': resnext50_32x4d_fc512,
    'se_resnet50': se_resnet50,
    'se_resnet50_fc512': se_resnet50_fc512,
    'se_resnet101': se_resnet101,
    'se_resnext50_32x4d': se_resnext50_32x4d,
    'se_resnext101_32x4d': se_resnext101_32x4d,
    'densenet121': densenet121,
    'densenet169': densenet169,
    'densenet201': densenet201,
    'densenet161': densenet161,
    'densenet121_fc512': densenet121_fc512,
    'inceptionresnetv2': inceptionresnetv2,
    'inceptionv4': inceptionv4,
    'xception': xception,
    # lightweight models
    'nasnetamobile': nasnetamobile,
    # NOTE: 'nasnsetmobile' is a historical misspelling of the key above.
    # It is kept as an alias so existing scripts/configs using it still work.
    'nasnsetmobile': nasnetamobile,
    'mobilenetv2_1dot0': mobilenetv2_1dot0,
    'mobilenetv2_1dot4': mobilenetv2_1dot4,
    'shufflenet': shufflenet,
    'squeezenet1_0': squeezenet1_0,
    'squeezenet1_0_fc512': squeezenet1_0_fc512,
    'squeezenet1_1': squeezenet1_1,
    # reid-specific models
    'mudeep': MuDeep,
    'resnet50mid': resnet50mid,
    'hacnn': HACNN,
    'pcb_p6': pcb_p6,
    'pcb_p4': pcb_p4,
    'mlfn': mlfn,
}
def show_avai_models():
    """Print the names of every model registered in the model factory.

    The names are printed as a single Python list, in registry order.

    Examples::
        >>> from torchreid import models
        >>> models.show_avai_models()
    """
    model_names = [model_name for model_name in __model_factory]
    print(model_names)
2019-03-20 01:26:08 +08:00
def build_model(name, num_classes, loss='softmax', pretrained=True, use_gpu=True):
    """Build a model by looking ``name`` up in the model registry.

    Args:
        name (str): registry key of the model, e.g. ``'resnet50'``.
        num_classes (int): number of training identities.
        loss (str, optional): loss function to optimize the model. Currently
            supports "softmax" and "triplet". Default is "softmax".
        pretrained (bool, optional): whether to load ImageNet-pretrained weights.
            Default is True.
        use_gpu (bool, optional): whether to use gpu. Default is True.

    Returns:
        nn.Module: whatever the registered constructor returns.

    Raises:
        KeyError: if ``name`` is not one of the registered model names.

    Examples::
        >>> from torchreid import models
        >>> model = models.build_model('resnet50', 751, loss='softmax')
    """
    registered = list(__model_factory.keys())
    # Membership is tested against the list (not the dict) so that an
    # unhashable ``name`` yields the KeyError below instead of a TypeError.
    if name in registered:
        print('Initializing model: {}'.format(name))
        constructor = __model_factory[name]
        return constructor(
            num_classes=num_classes,
            loss=loss,
            pretrained=pretrained,
            use_gpu=use_gpu,
        )
    raise KeyError('Unknown model: {}. Must be one of {}'.format(name, registered))