mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
support auto download model from bos (#9349)
This commit is contained in:
parent
a0c7e63009
commit
623424fce0
@ -10,7 +10,7 @@ Global:
|
|||||||
- 0
|
- 0
|
||||||
- 400
|
- 400
|
||||||
cal_metric_during_train: false
|
cal_metric_during_train: false
|
||||||
pretrained_model: null
|
pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||||
checkpoints: null
|
checkpoints: null
|
||||||
save_inference_dir: null
|
save_inference_dir: null
|
||||||
use_visualdl: false
|
use_visualdl: false
|
||||||
|
@ -67,6 +67,20 @@ def maybe_download(model_storage_directory, url):
|
|||||||
os.remove(tmp_path)
|
os.remove(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_download_params(model_path):
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
return model_path
|
||||||
|
elif not is_link(model_path):
|
||||||
|
url = 'https://paddleocr.bj.bcebos.com/' + model_path
|
||||||
|
else:
|
||||||
|
url = model_path
|
||||||
|
tmp_path = os.path.join(MODELS_DIR, url.split('/')[-1])
|
||||||
|
print('download {} to {}'.format(url, tmp_path))
|
||||||
|
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||||
|
download_with_progressbar(url, tmp_path)
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
def is_link(s):
|
def is_link(s):
|
||||||
return s is not None and s.startswith('http')
|
return s is not None and s.startswith('http')
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ import six
|
|||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
from ppocr.utils.network import maybe_download_params
|
||||||
|
|
||||||
__all__ = ['load_model']
|
__all__ = ['load_model']
|
||||||
|
|
||||||
@ -145,6 +146,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
|||||||
|
|
||||||
def load_pretrained_params(model, path):
|
def load_pretrained_params(model, path):
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
path = maybe_download_params(path)
|
||||||
if path.endswith('.pdparams'):
|
if path.endswith('.pdparams'):
|
||||||
path = path.replace('.pdparams', '')
|
path = path.replace('.pdparams', '')
|
||||||
assert os.path.exists(path + ".pdparams"), \
|
assert os.path.exists(path + ".pdparams"), \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user