fix save load, update det pretrain
parent
67d15d24d3
commit
3647c3188e
|
@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
|
|||
cd PaddleOCR/
|
||||
# 根据backbone的不同选择下载对应的预训练模型
|
||||
# 下载MobileNetV3的预训练模型
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
# 或,下载ResNet18_vd的预训练模型
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
|
||||
# 或,下载ResNet50_vd的预训练模型
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
|
||||
```
|
||||
|
||||
<a name="2-----"></a>
|
||||
|
|
|
@ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h
|
|||
```shell
|
||||
cd PaddleOCR/
|
||||
# Download the pre-trained model of MobileNetV3
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
# or, download the pre-trained model of ResNet18_vd
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
|
||||
# or, download the pre-trained model of ResNet50_vd
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -111,13 +111,16 @@ def load_pretrained_params(model, path):
|
|||
params = paddle.load(path + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
for k1 in params.keys():
|
||||
if k1 not in state_dict.keys():
|
||||
logger.warning("The pretrained params {} not in model".format(k1))
|
||||
else:
|
||||
logger.warning(
|
||||
"The shape of model params {} {} not matched with loaded params {} {} !".
|
||||
format(k1, state_dict[k1].shape, k2, params[k2].shape))
|
||||
if list(state_dict[k1].shape) == list(params[k1].shape):
|
||||
new_state_dict[k1] = params[k1]
|
||||
else:
|
||||
logger.warning(
|
||||
"The shape of model params {} {} not matched with loaded params {} {} !".
|
||||
format(k1, state_dict[k1].shape, k1, params[k1].shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info("load pretrain successful from {}".format(path))
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue