2018-03-12 05:17:48 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
|
|
|
from .ResNet import *
|
2018-05-10 18:46:59 +08:00
|
|
|
from .ResNeXt import *
|
|
|
|
from .SEResNet import *
|
2018-03-12 05:17:48 +08:00
|
|
|
from .DenseNet import *
|
2018-03-26 23:16:11 +08:00
|
|
|
from .MuDeep import *
|
2018-04-25 00:07:43 +08:00
|
|
|
from .HACNN import *
|
2018-04-28 18:14:14 +08:00
|
|
|
from .SqueezeNet import *
|
2018-04-29 21:19:24 +08:00
|
|
|
from .MobileNet import *
|
2018-04-30 19:29:56 +08:00
|
|
|
from .ShuffleNet import *
|
2018-04-30 22:56:46 +08:00
|
|
|
from .Xception import *
|
2018-05-05 04:39:32 +08:00
|
|
|
from .InceptionV4 import *
|
2018-05-04 21:06:08 +08:00
|
|
|
from .NASNet import *
|
2018-05-05 02:04:11 +08:00
|
|
|
from .DPN import *
|
2018-05-05 06:14:21 +08:00
|
|
|
from .InceptionResNetV2 import *
|
2018-03-12 05:17:48 +08:00
|
|
|
|
|
|
|
# Registry of available models: maps the public model name accepted by
# ``init_model`` to the callable (class) that constructs it.
# Entry order is preserved deliberately; on Python 3.7+ it is also the
# iteration order seen through ``get_names``.
__factory = {
    # Plain ResNet backbones.
    'resnet50': ResNet50,
    'resnet101': ResNet101,
    # SE-ResNet / SE-ResNeXt variants.
    'seresnet50': SEResNet50,
    'seresnet101': SEResNet101,
    'seresnext50': SEResNeXt50,
    'seresnext101': SEResNeXt101,
    # ResNeXt; the public key omits the "32x4d" suffix of the class name.
    'resnext101': ResNeXt101_32x4d,
    'resnet50m': ResNet50M,
    # Other ImageNet-style backbones.
    'densenet121': DenseNet121,
    'squeezenet': SqueezeNet,
    # NOTE: the key 'mobilenet' maps to the MobileNetV2 class.
    'mobilenet': MobileNetV2,
    'shufflenet': ShuffleNet,
    'xception': Xception,
    'inceptionv4': InceptionV4ReID,
    'nasnet': NASNetAMobile,
    # DPN is called with its own defaults; presumably that is the
    # 92-layer configuration implied by the key -- TODO confirm in DPN.py.
    'dpn92': DPN,
    'inceptionresnetv2': InceptionResNetV2,
    # Models defined in this project's own modules.
    'mudeep': MuDeep,
    'hacnn': HACNN,
}
|
|
|
|
|
|
|
|
def get_names():
    """Return the names of all registered models.

    Returns:
        A new list of the registry's keys, in the registry's iteration
        order. A fresh list is built on every call, so the result has the
        same type on Python 2 and Python 3. It also does not change if the
        registry is modified later, unlike the live ``dict_keys`` view that
        ``dict.keys()`` returns on Python 3.
    """
    return list(__factory)
|
|
|
|
|
|
|
|
def init_model(name, *args, **kwargs):
    """Construct the model registered under ``name``.

    Args:
        name: Registry key of the model (one of ``get_names()``).
        *args: Positional arguments forwarded unchanged to the constructor.
        **kwargs: Keyword arguments forwarded unchanged to the constructor.

    Returns:
        Whatever the registered constructor returns for these arguments.

    Raises:
        KeyError: If ``name`` is not a registered model name.
    """
    # Test membership on the dict itself. On Python 2, ``__factory.keys()``
    # would build a throwaway list and turn this check into a linear scan.
    if name not in __factory:
        raise KeyError("Unknown model: {}".format(name))
    # The lookup is kept outside any try/except, so a KeyError raised inside
    # the constructor is never mistaken for an unknown model name.
    return __factory[name](*args, **kwargs)
|