support http pretrained

pull/2989/head
zhangyubo0722 2023-09-07 07:34:03 +00:00 committed by Tingquan Gao
parent a119eb4191
commit 9e9d8c9d2b
1 changed files with 7 additions and 3 deletions

View File

@ -51,10 +51,14 @@ def _extract_student_weights(all_params, student_prefix="Student."):
def load_dygraph_pretrain(model, path=None): def load_dygraph_pretrain(model, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): if path.startswith(("http://", "https://")):
raise ValueError("Model pretrain path {}.pdparams does not " path = get_weights_path_from_url(path)
if not path.endswith('.pdparams'):
path = path + '.pdparams'
if not os.path.exists(path):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path)) "exists.".format(path))
param_state_dict = paddle.load(path + ".pdparams") param_state_dict = paddle.load(path)
if isinstance(model, list): if isinstance(model, list):
for m in model: for m in model:
if hasattr(m, 'set_dict'): if hasattr(m, 'set_dict'):