diff --git a/ppcls/arch/backbone/variant_models/resnet_variant.py b/ppcls/arch/backbone/variant_models/resnet_variant.py index 81eb71bb8..08042ad58 100644 --- a/ppcls/arch/backbone/variant_models/resnet_variant.py +++ b/ppcls/arch/backbone/variant_models/resnet_variant.py @@ -1,5 +1,5 @@ from paddle.nn import Conv2D -from ppcls.arch.backbone.legendary_models.resnet import ResNet50 +from ppcls.arch.backbone.legendary_models.resnet import ResNet50, MODEL_URLS, _load_pretrained __all__ = ["ResNet50_last_stage_stride1"] @@ -17,6 +17,7 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): return new_conv match_re = "conv2d_4[4|6]" - model = ResNet50(pretrained=pretrained, use_ssld=use_ssld, **kwargs) + model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs) model.replace_sub(match_re, replace_function, True) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld) return model