Merge init_model and load_dygraph_params into load_model (#4623)

* Merge init_model and load_dygraph_params into load_model
pull/4635/head
zhoujun 2021-11-12 11:06:36 +08:00 committed by GitHub
parent 1417a3c2cf
commit ae4167dc32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 47 additions and 89 deletions

View File

@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}")
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, None)
load_model(config, model)
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info(f"metric['hmean']: {metric['hmean']}")

View File

@ -32,7 +32,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
dist.get_world_size()
@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))

View File

@ -28,7 +28,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
from ppocr.metrics import build_metric
@ -101,7 +101,7 @@ def main():
quanter = QAT(config=quant_config)
quanter.quantize(model)
init_model(config, model)
load_model(config, model)
model.eval()
# build metric

View File

@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
from paddleslim.dygraph.quant import QAT
@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))

View File

@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
import paddleslim
from paddleslim.dygraph.quant import QAT

View File

@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.save_load import load_pretrained_params
__all__ = ['DistillationModel']

View File

@ -25,7 +25,7 @@ import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
__all__ = ['load_model']
def _mkdir_if_not_exist(path, logger):
@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
def init_model(config, model, optimizer=None, lr_scheduler=None):
def load_model(config, model, optimizer=None):
"""
load model from checkpoint or pretrained_model
"""
@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
if checkpoints:
assert os.path.exists(checkpoints + ".pdparams"), \
"Given dir {}.pdparams not exist.".format(checkpoints)
if checkpoints.endswith('pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_state_dict(para_dict)
f"The {checkpoints}.pdopt does not exists!"
load_pretrained_params(model, checkpoints)
optim_dict = paddle.load(checkpoints + '.pdopt')
if optimizer is not None:
optimizer.set_state_dict(opti_dict)
optimizer.set_state_dict(optim_dict)
if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f:
@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
for pretrained in pretrained_model:
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
return best_model_dict
def load_dygraph_params(config, model, logger, optimizer):
ckp = config['Global']['checkpoints']
if ckp and os.path.exists(ckp + ".pdparams"):
pre_best_model_dict = init_model(config, model, optimizer)
return pre_best_model_dict
else:
pm = config['Global']['pretrained_model']
if pm is None:
return {}
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
logger.info(f"The pretrained_model {pm} does not exists!")
return {}
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
def load_pretrained_params(model, path):
if path is None:
return False
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
print(f"The pretrained_model {path} does not exists!")
return False
logger = get_logger()
if path.endswith('pdparams'):
path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \
f"The {path}.pdparams does not exists!"
path = path if path.endswith('.pdparams') else path + '.pdparams'
params = paddle.load(path)
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
print(
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}")
logger.info(f"load pretrain successful from {path}")
return model

View File

@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import print_dict
import tools.program as program
@ -60,7 +60,7 @@ def main():
else:
model_type = None
best_model_dict = load_dygraph_params(config, model, logger, None)
best_model_dict = load_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():

View File

@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import print_dict
import tools.program as program
@ -57,7 +57,7 @@ def main():
model = build_model(config['Architecture'])
best_model_dict = load_dygraph_params(config, model, logger, None)
best_model_dict = load_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():

View File

@ -26,7 +26,7 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
@ -107,7 +107,7 @@ def main():
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model)
load_model(config, model)
model.eval()
save_path = config["Global"]["save_inference_dir"]

View File

@ -32,7 +32,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
@ -47,7 +47,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model)
load_model(config, model)
# create data ops
transforms = []

View File

@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
@ -59,7 +59,7 @@ def main():
# build model
model = build_model(config['Architecture'])
_ = load_dygraph_params(config, model, logger, None)
load_model(config, model)
# build post process
post_process_class = build_post_process(config['PostProcess'])

View File

@ -34,7 +34,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
@ -68,7 +68,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model)
load_model(config, model)
# build post process
post_process_class = build_post_process(config['PostProcess'],

View File

@ -33,7 +33,7 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
@ -58,7 +58,7 @@ def main():
model = build_model(config['Architecture'])
init_model(config, model)
load_model(config, model)
# create data ops
transforms = []
@ -75,9 +75,7 @@ def main():
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = [
'image', 'valid_ratio'
]
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)

View File

@ -34,11 +34,12 @@ from paddle.jit import to_static
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
import cv2
def main(config, device, logger, vdl_writer):
global_config = config['Global']
@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
model = build_model(config['Architecture'])
init_model(config, model, logger)
load_model(config, model)
# create data ops
transforms = []
@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer)

View File

@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.save_load import load_model
import tools.program as program
dist.get_world_size()
@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(