mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
support auto download model from bos (#9349)
This commit is contained in:
parent
a0c7e63009
commit
623424fce0
@ -10,7 +10,7 @@ Global:
|
||||
- 0
|
||||
- 400
|
||||
cal_metric_during_train: false
|
||||
pretrained_model: null
|
||||
pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
checkpoints: null
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
|
@ -67,6 +67,20 @@ def maybe_download(model_storage_directory, url):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def maybe_download_params(model_path):
|
||||
if os.path.exists(model_path):
|
||||
return model_path
|
||||
elif not is_link(model_path):
|
||||
url = 'https://paddleocr.bj.bcebos.com/' + model_path
|
||||
else:
|
||||
url = model_path
|
||||
tmp_path = os.path.join(MODELS_DIR, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def is_link(s):
|
||||
return s is not None and s.startswith('http')
|
||||
|
||||
|
@ -24,6 +24,7 @@ import six
|
||||
import paddle
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.network import maybe_download_params
|
||||
|
||||
__all__ = ['load_model']
|
||||
|
||||
@ -145,6 +146,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
logger = get_logger()
|
||||
path = maybe_download_params(path)
|
||||
if path.endswith('.pdparams'):
|
||||
path = path.replace('.pdparams', '')
|
||||
assert os.path.exists(path + ".pdparams"), \
|
||||
|
Loading…
x
Reference in New Issue
Block a user