From 623424fce062bd4cc31ed84a8c470841775c05a8 Mon Sep 17 00:00:00 2001 From: zhoujun Date: Wed, 8 Mar 2023 19:21:28 +0800 Subject: [PATCH] support auto download model from bos (#9349) --- .../det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml | 2 +- ppocr/utils/network.py | 14 ++++++++++++++ ppocr/utils/save_load.py | 2 ++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml index 0e8af77647..083383a00f 100644 --- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml +++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml @@ -10,7 +10,7 @@ Global: - 0 - 400 cal_metric_during_train: false - pretrained_model: null + pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams checkpoints: null save_inference_dir: null use_visualdl: false diff --git a/ppocr/utils/network.py b/ppocr/utils/network.py index 080a5d1601..327863f7d4 100644 --- a/ppocr/utils/network.py +++ b/ppocr/utils/network.py @@ -67,6 +67,20 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) +def maybe_download_params(model_path): + if os.path.exists(model_path): + return model_path + elif not is_link(model_path): + url = 'https://paddleocr.bj.bcebos.com/' + model_path + else: + url = model_path + tmp_path = os.path.join(MODELS_DIR, url.split('/')[-1]) + print('download {} to {}'.format(url, tmp_path)) + os.makedirs(MODELS_DIR, exist_ok=True) + download_with_progressbar(url, tmp_path) + return tmp_path + + def is_link(s): return s is not None and s.startswith('http') diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index aa65f290c0..cc3f3c01e3 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -24,6 +24,7 @@ import six import paddle from ppocr.utils.logging import get_logger +from ppocr.utils.network import maybe_download_params __all__ = ['load_model'] @@ -145,6 +146,7 @@ def load_model(config, model, optimizer=None, model_type='det'): def load_pretrained_params(model, path): logger = get_logger() + path = maybe_download_params(path) if path.endswith('.pdparams'): path = path.replace('.pdparams', '') assert os.path.exists(path + ".pdparams"), \