119 lines
3.4 KiB
Python
119 lines
3.4 KiB
Python
from __future__ import absolute_import
|
|
|
|
import torch
|
|
|
|
from .resnet import *
|
|
from .resnetmid import *
|
|
from .resnet_ibn_a import *
|
|
from .resnet_ibn_b import *
|
|
from .senet import *
|
|
from .densenet import *
|
|
from .inceptionresnetv2 import *
|
|
from .inceptionv4 import *
|
|
from .xception import *
|
|
|
|
from .nasnet import *
|
|
from .mobilenetv2 import *
|
|
from .shufflenet import *
|
|
from .squeezenet import *
|
|
from .shufflenetv2 import *
|
|
|
|
from .mudeep import *
|
|
from .hacnn import *
|
|
from .pcb import *
|
|
from .mlfn import *
|
|
from .osnet import *
|
|
from .osnet_ain import *
|
|
|
|
|
|
__model_factory = {
|
|
# image classification models
|
|
'resnet18': resnet18,
|
|
'resnet34': resnet34,
|
|
'resnet50': resnet50,
|
|
'resnet101': resnet101,
|
|
'resnet152': resnet152,
|
|
'resnext50_32x4d': resnext50_32x4d,
|
|
'resnext101_32x8d': resnext101_32x8d,
|
|
'resnet50_fc512': resnet50_fc512,
|
|
'se_resnet50': se_resnet50,
|
|
'se_resnet50_fc512': se_resnet50_fc512,
|
|
'se_resnet101': se_resnet101,
|
|
'se_resnext50_32x4d': se_resnext50_32x4d,
|
|
'se_resnext101_32x4d': se_resnext101_32x4d,
|
|
'densenet121': densenet121,
|
|
'densenet169': densenet169,
|
|
'densenet201': densenet201,
|
|
'densenet161': densenet161,
|
|
'densenet121_fc512': densenet121_fc512,
|
|
'inceptionresnetv2': inceptionresnetv2,
|
|
'inceptionv4': inceptionv4,
|
|
'xception': xception,
|
|
'resnet50_ibn_a': resnet50_ibn_a,
|
|
'resnet50_ibn_b': resnet50_ibn_b,
|
|
# lightweight models
|
|
'nasnsetmobile': nasnetamobile,
|
|
'mobilenetv2_x1_0': mobilenetv2_x1_0,
|
|
'mobilenetv2_x1_4': mobilenetv2_x1_4,
|
|
'shufflenet': shufflenet,
|
|
'squeezenet1_0': squeezenet1_0,
|
|
'squeezenet1_0_fc512': squeezenet1_0_fc512,
|
|
'squeezenet1_1': squeezenet1_1,
|
|
'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
|
|
'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
|
|
'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
|
|
'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
|
|
# reid-specific models
|
|
'mudeep': MuDeep,
|
|
'resnet50mid': resnet50mid,
|
|
'hacnn': HACNN,
|
|
'pcb_p6': pcb_p6,
|
|
'pcb_p4': pcb_p4,
|
|
'mlfn': mlfn,
|
|
'osnet_x1_0': osnet_x1_0,
|
|
'osnet_x0_75': osnet_x0_75,
|
|
'osnet_x0_5': osnet_x0_5,
|
|
'osnet_x0_25': osnet_x0_25,
|
|
'osnet_ibn_x1_0': osnet_ibn_x1_0,
|
|
'osnet_ain_x1_0': osnet_ain_x1_0
|
|
}
|
|
|
|
|
|
def show_avai_models():
|
|
"""Displays available models.
|
|
|
|
Examples::
|
|
>>> from torchreid import models
|
|
>>> models.show_avai_models()
|
|
"""
|
|
print(list(__model_factory.keys()))
|
|
|
|
|
|
def build_model(name, num_classes, loss='softmax', pretrained=True, use_gpu=True):
|
|
"""A function wrapper for building a model.
|
|
|
|
Args:
|
|
name (str): model name.
|
|
num_classes (int): number of training identities.
|
|
loss (str, optional): loss function to optimize the model. Currently
|
|
supports "softmax" and "triplet". Default is "softmax".
|
|
pretrained (bool, optional): whether to load ImageNet-pretrained weights.
|
|
Default is True.
|
|
use_gpu (bool, optional): whether to use gpu. Default is True.
|
|
|
|
Returns:
|
|
nn.Module
|
|
|
|
Examples::
|
|
>>> from torchreid import models
|
|
>>> model = models.build_model('resnet50', 751, loss='softmax')
|
|
"""
|
|
avai_models = list(__model_factory.keys())
|
|
if name not in avai_models:
|
|
raise KeyError('Unknown model: {}. Must be one of {}'.format(name, avai_models))
|
|
return __model_factory[name](
|
|
num_classes=num_classes,
|
|
loss=loss,
|
|
pretrained=pretrained,
|
|
use_gpu=use_gpu
|
|
) |