diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index e48d2968b..391f67d3d 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -51,10 +51,14 @@ def _extract_student_weights(all_params, student_prefix="Student."): def load_dygraph_pretrain(model, path=None): - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {}.pdparams does not " + if path.startswith(("http://", "https://")): + path = get_weights_path_from_url(path) + if not path.endswith('.pdparams'): + path = path + '.pdparams' + if not os.path.exists(path): + raise ValueError("Model pretrain path {} does not " "exists.".format(path)) - param_state_dict = paddle.load(path + ".pdparams") + param_state_dict = paddle.load(path) if isinstance(model, list): for m in model: if hasattr(m, 'set_dict'):