support auto download model from bos (#9349)

This commit is contained in:
zhoujun 2023-03-08 19:21:28 +08:00 committed by GitHub
parent a0c7e63009
commit 623424fce0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 1 deletions

View File

@ -10,7 +10,7 @@ Global:
- 0
- 400
cal_metric_during_train: false
pretrained_model: null
pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
checkpoints: null
save_inference_dir: null
use_visualdl: false

View File

@ -67,6 +67,20 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path)
def maybe_download_params(model_path):
if os.path.exists(model_path):
return model_path
elif not is_link(model_path):
url = 'https://paddleocr.bj.bcebos.com/' + model_path
else:
url = model_path
tmp_path = os.path.join(MODELS_DIR, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(MODELS_DIR, exist_ok=True)
download_with_progressbar(url, tmp_path)
return tmp_path
def is_link(s):
return s is not None and s.startswith('http')

View File

@ -24,6 +24,7 @@ import six
import paddle
from ppocr.utils.logging import get_logger
from ppocr.utils.network import maybe_download_params
__all__ = ['load_model']
@ -145,6 +146,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
def load_pretrained_params(model, path):
logger = get_logger()
path = maybe_download_params(path)
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \