fix save load, update det pretrain
parent
67d15d24d3
commit
3647c3188e
|
@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
|
||||||
cd PaddleOCR/
|
cd PaddleOCR/
|
||||||
# 根据backbone的不同选择下载对应的预训练模型
|
# 根据backbone的不同选择下载对应的预训练模型
|
||||||
# 下载MobileNetV3的预训练模型
|
# 下载MobileNetV3的预训练模型
|
||||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||||
# 或,下载ResNet18_vd的预训练模型
|
# 或,下载ResNet18_vd的预训练模型
|
||||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams
|
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
|
||||||
# 或,下载ResNet50_vd的预训练模型
|
# 或,下载ResNet50_vd的预训练模型
|
||||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
|
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="2-----"></a>
|
<a name="2-----"></a>
|
||||||
|
|
|
@ -67,11 +67,11 @@ And the responding download link of backbone pretrain weights can be found in (h
|
||||||
```shell
|
```shell
|
||||||
cd PaddleOCR/
|
cd PaddleOCR/
|
||||||
# Download the pre-trained model of MobileNetV3
|
# Download the pre-trained model of MobileNetV3
|
||||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||||
# or, download the pre-trained model of ResNet18_vd
|
# or, download the pre-trained model of ResNet18_vd
|
||||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams
|
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
|
||||||
# or, download the pre-trained model of ResNet50_vd
|
# or, download the pre-trained model of ResNet50_vd
|
||||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
|
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -111,13 +111,16 @@ def load_pretrained_params(model, path):
|
||||||
params = paddle.load(path + '.pdparams')
|
params = paddle.load(path + '.pdparams')
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
for k1 in params.keys():
|
||||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
if k1 not in state_dict.keys():
|
||||||
new_state_dict[k1] = params[k2]
|
logger.warning("The pretrained params {} not in model".format(k1))
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
if list(state_dict[k1].shape) == list(params[k1].shape):
|
||||||
"The shape of model params {} {} not matched with loaded params {} {} !".
|
new_state_dict[k1] = params[k1]
|
||||||
format(k1, state_dict[k1].shape, k2, params[k2].shape))
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"The shape of model params {} {} not matched with loaded params {} {} !".
|
||||||
|
format(k1, state_dict[k1].shape, k1, params[k1].shape))
|
||||||
model.set_state_dict(new_state_dict)
|
model.set_state_dict(new_state_dict)
|
||||||
logger.info("load pretrain successful from {}".format(path))
|
logger.info("load pretrain successful from {}".format(path))
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue