删除backbone里的pretrained_model字段

pull/7832/head
zhiminzhang0830 2022-10-10 14:20:28 +08:00
parent 03802c7f93
commit 67ae525e95
2 changed files with 1 additions and 32 deletions
ppocr/modeling/backbones

View File

@ -8,7 +8,7 @@ Global:
# evaluation is run every 1260 iterations
eval_batch_step: [37800, 1260]
cal_metric_during_train: False
pretrained_model:
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
@ -23,8 +23,6 @@ Architecture:
Backbone:
name: ResNet_vd
layers: 50
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained.pdparams
Neck:
name: FPN_UNet
in_channels: [256, 512, 1024, 2048]

View File

@ -16,8 +16,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle
from paddle import ParamAttr
import paddle.nn as nn
@ -27,8 +25,6 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
from ppocr.utils.logging import get_logger
__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
@ -250,7 +246,6 @@ class ResNet_vd(nn.Layer):
layers=50,
dcn_stage=None,
out_indices=None,
pretrained_model=None,
**kwargs):
super(ResNet_vd, self).__init__()
@ -344,30 +339,6 @@ class ResNet_vd(nn.Layer):
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
if pretrained_model is not None:
self.load_pretrained_params(pretrained_model)
def load_pretrained_params(self, path):
logger = get_logger()
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams')
state_dict = self.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
self.set_state_dict(new_state_dict)
logger.info(f"loaded backbone pretrained_model successful from {path}")
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)