From 216291618c5a773b7792cb6a66a5d161a11263b1 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Fri, 6 Jan 2023 12:05:58 +0000 Subject: [PATCH] add the pretrained url --- ppcls/arch/backbone/legendary_models/resnet.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index c38651f46..7a4f3b37a 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -34,6 +34,8 @@ from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_fro MODEL_URLS = { "ResNet18": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams", + "ResNet18_dbb": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_dbb_pretrained.pdparams", "ResNet18_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams", "ResNet34": @@ -429,7 +431,10 @@ def _load_pretrained(pretrained, model, model_url, use_ssld): ) -def ResNet18(pretrained=False, use_ssld=False, **kwargs): +def ResNet18(pretrained=False, + use_ssld=False, + layer_type="ConvBNLayer", + **kwargs): """ ResNet18 Args: @@ -443,8 +448,13 @@ def ResNet18(pretrained=False, use_ssld=False, **kwargs): config=NET_CONFIG["18"], stages_pattern=MODEL_STAGES_PATTERN["ResNet18"], version="vb", + layer_type=layer_type, **kwargs) - _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld) + if layer_type == "DiverseBranchBlock": + _load_pretrained(pretrained, model, MODEL_URLS["ResNet18_dbb"], + use_ssld) + else: + _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld) return model