add load_dygraph_pretrain_from_url function
parent
5671f9d9bb
commit
31e59dfa1b
|
@ -24,6 +24,7 @@ import tempfile
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.static import load_program_state
|
from paddle.static import load_program_state
|
||||||
|
from paddle.utils.download import get_weights_path_from_url
|
||||||
|
|
||||||
from ppcls.utils import logger
|
from ppcls.utils import logger
|
||||||
|
|
||||||
|
@ -70,6 +71,14 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld, load_static_weights=False):
|
||||||
|
if use_ssld:
|
||||||
|
pretrained_url = pretrained_url.replace("_pretrained", "_ssld_pretrained")
|
||||||
|
local_weight_path = get_weights_path_from_url(pretrained_url).replace(".pdparams", "")
|
||||||
|
load_dygraph_pretrain(model, path=local_weight_path, load_static_weights=load_static_weights)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def load_distillation_model(model, pretrained_model, load_static_weights):
|
def load_distillation_model(model, pretrained_model, load_static_weights):
|
||||||
logger.info("In distillation mode, teacher model will be "
|
logger.info("In distillation mode, teacher model will be "
|
||||||
"loaded firstly before student model.")
|
"loaded firstly before student model.")
|
||||||
|
|
Loading…
Reference in New Issue