PaddleClas/hubconf.py

229 lines
6.6 KiB
Python
Raw Normal View History

2021-03-24 10:40:57 +08:00
2021-03-24 16:57:21 +08:00
# Third-party packages that the hub entry points in this file need.
# NOTE(review): presumably read by the `paddle.hub` loader to check that they
# are installed before the entry points are imported — confirm.
dependencies = ["paddle", "numpy"]
2021-03-24 10:40:57 +08:00
2021-03-25 13:38:05 +08:00
import paddle
2021-03-24 10:40:57 +08:00
2021-03-25 14:38:31 +08:00
from ppcls.modeling.architectures import alexnet as _alexnet
from ppcls.modeling.architectures import vgg as _vgg
2021-03-25 14:26:52 +08:00
from ppcls.modeling.architectures import resnet as _resnet
2021-03-25 14:48:00 +08:00
from ppcls.modeling.architectures import squeezenet as _squeezenet
from ppcls.modeling.architectures import densenet as _densenet
from ppcls.modeling.architectures import inception_v3 as _inception_v3
from ppcls.modeling.architectures import inception_v4 as _inception_v4
2021-03-25 15:10:17 +08:00
from ppcls.modeling.architectures import googlenet as _googlenet
from ppcls.modeling.architectures import shufflenet_v2 as _shufflenet_v2
from ppcls.modeling.architectures import mobilenet_v1 as _mobilenet_v1
from ppcls.modeling.architectures import mobilenet_v2 as _mobilenet_v2
from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3
from ppcls.modeling.architectures import resnext as _resnext
2021-03-25 14:48:00 +08:00
2021-03-25 14:26:52 +08:00
# _checkpoints = {
# 'ResNet18': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams',
# 'ResNet34': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams',
# }
2021-03-25 13:38:05 +08:00
2021-03-25 14:48:00 +08:00
2021-03-25 14:26:52 +08:00
def _load_pretrained_urls():
'''Load pretrained model parameters url from README.md
'''
import re
2021-03-25 14:48:00 +08:00
import os
2021-03-25 14:26:52 +08:00
from collections import OrderedDict
2021-03-25 14:48:00 +08:00
readme_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README.md')
with open(readme_path, 'r') as f:
2021-03-25 14:26:52 +08:00
lines = f.readlines()
lines = [lin for lin in lines if lin.strip().startswith('|') and 'Download link' in lin]
urls = OrderedDict()
for lin in lines:
try:
name = re.findall(r'\|(.*?)\|', lin)[0].strip().replace('<br>', '')
url = re.findall(r'\((.*?)\)', lin)[-1].strip()
if name in url:
urls[name] = url
except:
pass
return urls
# Model name -> pretrained-parameters URL, parsed from README.md when this
# module is imported; the entry points below look their checkpoints up here.
_checkpoints = _load_pretrained_urls()
2021-03-25 13:38:05 +08:00
2021-03-24 16:37:50 +08:00
2021-03-25 14:38:31 +08:00
def AlexNet(*, pretrained=False, **kwargs):
    '''AlexNet

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_alexnet.AlexNet``.

    Returns:
        The model built by ``_alexnet.AlexNet``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``AlexNet``.
    '''
    model = _alexnet.AlexNet(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'AlexNet' not in _checkpoints:
            raise AssertionError('Not provide `AlexNet` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['AlexNet'])
        model.set_state_dict(paddle.load(path))
    return model
def VGG11(*, pretrained=False, **kwargs):
    '''VGG11

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_vgg.VGG11``.

    Returns:
        The model built by ``_vgg.VGG11``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``VGG11``.
    '''
    model = _vgg.VGG11(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'VGG11' not in _checkpoints:
            raise AssertionError('Not provide `VGG11` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG11'])
        model.set_state_dict(paddle.load(path))
    return model
def VGG13(*, pretrained=False, **kwargs):
    '''VGG13

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_vgg.VGG13``.

    Returns:
        The model built by ``_vgg.VGG13``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``VGG13``.
    '''
    model = _vgg.VGG13(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'VGG13' not in _checkpoints:
            raise AssertionError('Not provide `VGG13` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG13'])
        model.set_state_dict(paddle.load(path))
    return model
def VGG16(*, pretrained=False, **kwargs):
    '''VGG16

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_vgg.VGG16``.

    Returns:
        The model built by ``_vgg.VGG16``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``VGG16``.
    '''
    model = _vgg.VGG16(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'VGG16' not in _checkpoints:
            raise AssertionError('Not provide `VGG16` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG16'])
        model.set_state_dict(paddle.load(path))
    return model
def VGG19(*, pretrained=False, **kwargs):
    '''VGG19

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_vgg.VGG19``.

    Returns:
        The model built by ``_vgg.VGG19``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``VGG19``.
    '''
    model = _vgg.VGG19(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'VGG19' not in _checkpoints:
            raise AssertionError('Not provide `VGG19` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG19'])
        model.set_state_dict(paddle.load(path))
    return model
2021-03-24 16:37:50 +08:00
def ResNet18(*, pretrained=False, **kwargs):
    '''ResNet18

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_resnet.ResNet18``.

    Returns:
        The model built by ``_resnet.ResNet18``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``ResNet18``.
    '''
    model = _resnet.ResNet18(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'ResNet18' not in _checkpoints:
            raise AssertionError('Not provide `ResNet18` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet18'])
        model.set_state_dict(paddle.load(path))
    return model
2021-03-24 16:37:50 +08:00
def ResNet34(*, pretrained=False, **kwargs):
    '''ResNet34

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_resnet.ResNet34``.

    Returns:
        The model built by ``_resnet.ResNet34``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``ResNet34``.
    '''
    model = _resnet.ResNet34(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'ResNet34' not in _checkpoints:
            raise AssertionError('Not provide `ResNet34` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet34'])
        model.set_state_dict(paddle.load(path))
    return model
2021-03-24 16:37:50 +08:00
2021-03-25 14:38:31 +08:00
def ResNet50(*, pretrained=False, **kwargs):
    '''ResNet50

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_resnet.ResNet50``.

    Returns:
        The model built by ``_resnet.ResNet50``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``ResNet50``.
    '''
    model = _resnet.ResNet50(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'ResNet50' not in _checkpoints:
            raise AssertionError('Not provide `ResNet50` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet50'])
        model.set_state_dict(paddle.load(path))
    return model
def ResNet101(*, pretrained=False, **kwargs):
    '''ResNet101

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_resnet.ResNet101``.

    Returns:
        The model built by ``_resnet.ResNet101``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``ResNet101``.
    '''
    model = _resnet.ResNet101(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'ResNet101' not in _checkpoints:
            raise AssertionError('Not provide `ResNet101` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet101'])
        model.set_state_dict(paddle.load(path))
    return model
def ResNet152(*, pretrained=False, **kwargs):
    '''ResNet152

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_resnet.ResNet152``.

    Returns:
        The model built by ``_resnet.ResNet152``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``ResNet152``.
    '''
    model = _resnet.ResNet152(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'ResNet152' not in _checkpoints:
            raise AssertionError('Not provide `ResNet152` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet152'])
        model.set_state_dict(paddle.load(path))
    return model
2021-03-25 15:10:17 +08:00
def SqueezeNet1_0(*, pretrained=False, **kwargs):
    '''SqueezeNet1_0

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_squeezenet.SqueezeNet1_0``.

    Returns:
        The model built by ``_squeezenet.SqueezeNet1_0``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``SqueezeNet1_0``.
    '''
    model = _squeezenet.SqueezeNet1_0(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'SqueezeNet1_0' not in _checkpoints:
            raise AssertionError('Not provide `SqueezeNet1_0` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_0'])
        model.set_state_dict(paddle.load(path))
    return model
def SqueezeNet1_1(*, pretrained=False, **kwargs):
    '''SqueezeNet1_1

    Args:
        pretrained (bool): if True, resolve this model's entry in
            ``_checkpoints`` to a local file with
            ``paddle.utils.download.get_weights_path_from_url`` and load it
            via ``set_state_dict``. Default: False.
        **kwargs: forwarded unchanged to ``_squeezenet.SqueezeNet1_1``.

    Returns:
        The model built by ``_squeezenet.SqueezeNet1_1``.

    Raises:
        AssertionError: if ``pretrained`` is True and no checkpoint URL is
            known for ``SqueezeNet1_1``.
    '''
    model = _squeezenet.SqueezeNet1_1(**kwargs)
    if pretrained:
        # Explicit check instead of `assert` so it still runs under
        # `python -O`; AssertionError is kept for backward compatibility.
        if 'SqueezeNet1_1' not in _checkpoints:
            raise AssertionError('Not provide `SqueezeNet1_1` pretrained model.')
        path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_1'])
        model.set_state_dict(paddle.load(path))
    return model