diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index 391f67d3d..35d11b83f 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -74,9 +74,13 @@ def load_dygraph_pretrain_from_url(model, use_ssld=False, use_imagenet22k_pretrained=False, use_imagenet22kto1k_pretrained=False): - if use_ssld: - pretrained_url = pretrained_url.replace("_pretrained", - "_ssld_pretrained") + if "ssld" not in pretrained_url: + if use_ssld: + pretrained_url = pretrained_url.replace("_pretrained", + "_ssld_pretrained") + else: + pretrained_url = pretrained_url.replace("ssld_pretrained", + "ssld_stage1_pretrained") if use_imagenet22k_pretrained: pretrained_url = pretrained_url.replace("_pretrained", "_22k_pretrained")