support http pretrained
parent
a119eb4191
commit
9e9d8c9d2b
|
@ -51,10 +51,14 @@ def _extract_student_weights(all_params, student_prefix="Student."):
|
|||
|
||||
|
||||
def load_dygraph_pretrain(model, path=None):
|
||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {}.pdparams does not "
|
||||
if path.startswith(("http://", "https://")):
|
||||
path = get_weights_path_from_url(path)
|
||||
if not path.endswith('.pdparams'):
|
||||
path = path + '.pdparams'
|
||||
if not os.path.exists(path):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(path))
|
||||
param_state_dict = paddle.load(path + ".pdparams")
|
||||
param_state_dict = paddle.load(path)
|
||||
if isinstance(model, list):
|
||||
for m in model:
|
||||
if hasattr(m, 'set_dict'):
|
||||
|
|
Loading…
Reference in New Issue