using fixed format pretrained link
parent
789fcb2c12
commit
16c6152e49
32
hubconf.py
32
hubconf.py
|
@ -32,38 +32,10 @@ from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3
|
|||
from ppcls.modeling.architectures import resnext as _resnext
|
||||
|
||||
|
||||
def _load_pretrained_urls():
|
||||
'''Load pretrained model parameters url from README.md
|
||||
'''
|
||||
import re
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
readme_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README.md')
|
||||
|
||||
with open(readme_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
lines = [lin for lin in lines if lin.strip().startswith('|') and 'Download link' in lin]
|
||||
|
||||
urls = OrderedDict()
|
||||
for lin in lines:
|
||||
try:
|
||||
name = re.findall(r'\|(.*?)\|', lin)[0].strip().replace('<br>', '')
|
||||
url = re.findall(r'\((.*?)\)', lin)[-1].strip()
|
||||
if name in url:
|
||||
urls[name] = url
|
||||
except:
|
||||
pass
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
_checkpoints = _load_pretrained_urls()
|
||||
|
||||
|
||||
def _load_pretrained_parameters(model, name):
|
||||
assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name)
|
||||
path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name])
|
||||
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(name)
|
||||
path = paddle.utils.download.get_weights_path_from_url(url)
|
||||
model.set_state_dict(paddle.load(path))
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue