merge init_model and load_dygraph_params to load_model (#4623)
* merge init_model and load_dygraph_params to load_modelpull/4635/head
parent
1417a3c2cf
commit
ae4167dc32
|
@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
|
|||
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
import tools.program as program
|
||||
|
||||
|
||||
|
@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
|
|||
logger.info(f"FLOPs after pruning: {flops}")
|
||||
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, logger, None)
|
||||
load_model(config, model)
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
logger.info(f"metric['hmean']: {metric['hmean']}")
|
||||
|
|
|
@ -32,7 +32,7 @@ from ppocr.losses import build_loss
|
|||
from ppocr.optimizer import build_optimizer
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
import tools.program as program
|
||||
|
||||
dist.get_world_size()
|
||||
|
@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
|
|||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
||||
pre_best_model_dict = load_model(config, model, optimizer)
|
||||
|
||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||
format(len(train_dataloader), len(valid_dataloader)))
|
||||
|
|
|
@ -28,7 +28,7 @@ from paddle.jit import to_static
|
|||
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.logging import get_logger
|
||||
from tools.program import load_config, merge_config, ArgsParser
|
||||
from ppocr.metrics import build_metric
|
||||
|
@ -101,7 +101,7 @@ def main():
|
|||
quanter = QAT(config=quant_config)
|
||||
quanter.quantize(model)
|
||||
|
||||
init_model(config, model)
|
||||
load_model(config, model)
|
||||
model.eval()
|
||||
|
||||
# build metric
|
||||
|
|
|
@ -37,7 +37,7 @@ from ppocr.losses import build_loss
|
|||
from ppocr.optimizer import build_optimizer
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
import tools.program as program
|
||||
from paddleslim.dygraph.quant import QAT
|
||||
|
||||
|
@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
|
|||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
||||
pre_best_model_dict = load_model(config, model, optimizer)
|
||||
|
||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||
format(len(train_dataloader), len(valid_dataloader)))
|
||||
|
|
|
@ -37,7 +37,7 @@ from ppocr.losses import build_loss
|
|||
from ppocr.optimizer import build_optimizer
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
import tools.program as program
|
||||
import paddleslim
|
||||
from paddleslim.dygraph.quant import QAT
|
||||
|
|
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
|||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
from ppocr.utils.save_load import load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import paddle
|
|||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
|
||||
__all__ = ['load_model']
|
||||
|
||||
|
||||
def _mkdir_if_not_exist(path, logger):
|
||||
|
@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
|
|||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def init_model(config, model, optimizer=None, lr_scheduler=None):
|
||||
def load_model(config, model, optimizer=None):
|
||||
"""
|
||||
load model from checkpoint or pretrained_model
|
||||
"""
|
||||
|
@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
|
|||
pretrained_model = global_config.get('pretrained_model')
|
||||
best_model_dict = {}
|
||||
if checkpoints:
|
||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
||||
"Given dir {}.pdparams not exist.".format(checkpoints)
|
||||
if checkpoints.endswith('pdparams'):
|
||||
checkpoints = checkpoints.replace('.pdparams', '')
|
||||
assert os.path.exists(checkpoints + ".pdopt"), \
|
||||
"Given dir {}.pdopt not exist.".format(checkpoints)
|
||||
para_dict = paddle.load(checkpoints + '.pdparams')
|
||||
opti_dict = paddle.load(checkpoints + '.pdopt')
|
||||
model.set_state_dict(para_dict)
|
||||
f"The {checkpoints}.pdopt does not exists!"
|
||||
load_pretrained_params(model, checkpoints)
|
||||
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||
if optimizer is not None:
|
||||
optimizer.set_state_dict(opti_dict)
|
||||
optimizer.set_state_dict(optim_dict)
|
||||
|
||||
if os.path.exists(checkpoints + '.states'):
|
||||
with open(checkpoints + '.states', 'rb') as f:
|
||||
|
@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
|
|||
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
||||
logger.info("resume from {}".format(checkpoints))
|
||||
elif pretrained_model:
|
||||
if not isinstance(pretrained_model, list):
|
||||
pretrained_model = [pretrained_model]
|
||||
for pretrained in pretrained_model:
|
||||
if not (os.path.isdir(pretrained) or
|
||||
os.path.exists(pretrained + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(pretrained))
|
||||
param_state_dict = paddle.load(pretrained + '.pdparams')
|
||||
model.set_state_dict(param_state_dict)
|
||||
logger.info("load pretrained model from {}".format(
|
||||
pretrained_model))
|
||||
load_pretrained_params(model, pretrained_model)
|
||||
else:
|
||||
logger.info('train from scratch')
|
||||
return best_model_dict
|
||||
|
||||
|
||||
def load_dygraph_params(config, model, logger, optimizer):
|
||||
ckp = config['Global']['checkpoints']
|
||||
if ckp and os.path.exists(ckp + ".pdparams"):
|
||||
pre_best_model_dict = init_model(config, model, optimizer)
|
||||
return pre_best_model_dict
|
||||
else:
|
||||
pm = config['Global']['pretrained_model']
|
||||
if pm is None:
|
||||
return {}
|
||||
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
|
||||
logger.info(f"The pretrained_model {pm} does not exists!")
|
||||
return {}
|
||||
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
|
||||
params = paddle.load(pm)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
if path is None:
|
||||
return False
|
||||
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
|
||||
print(f"The pretrained_model {path} does not exists!")
|
||||
return False
|
||||
logger = get_logger()
|
||||
if path.endswith('pdparams'):
|
||||
path = path.replace('.pdparams', '')
|
||||
assert os.path.exists(path + ".pdparams"), \
|
||||
f"The {path}.pdparams does not exists!"
|
||||
|
||||
path = path if path.endswith('.pdparams') else path + '.pdparams'
|
||||
params = paddle.load(path)
|
||||
params = paddle.load(path + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
print(
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
print(f"load pretrain successful from {path}")
|
||||
logger.info(f"load pretrain successful from {path}")
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
|||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
@ -60,7 +60,7 @@ def main():
|
|||
else:
|
||||
model_type = None
|
||||
|
||||
best_model_dict = load_dygraph_params(config, model, logger, None)
|
||||
best_model_dict = load_model(config, model)
|
||||
if len(best_model_dict):
|
||||
logger.info('metric in ckpt ***************')
|
||||
for k, v in best_model_dict.items():
|
||||
|
|
|
@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
|||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
@ -57,7 +57,7 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
best_model_dict = load_dygraph_params(config, model, logger, None)
|
||||
best_model_dict = load_model(config, model)
|
||||
if len(best_model_dict):
|
||||
logger.info('metric in ckpt ***************')
|
||||
for k, v in best_model_dict.items():
|
||||
|
|
|
@ -26,7 +26,7 @@ from paddle.jit import to_static
|
|||
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.logging import get_logger
|
||||
from tools.program import load_config, merge_config, ArgsParser
|
||||
|
||||
|
@ -107,7 +107,7 @@ def main():
|
|||
else: # base rec model
|
||||
config["Architecture"]["Head"]["out_channels"] = char_num
|
||||
model = build_model(config["Architecture"])
|
||||
init_model(config, model)
|
||||
load_model(config, model)
|
||||
model.eval()
|
||||
|
||||
save_path = config["Global"]["save_inference_dir"]
|
||||
|
|
|
@ -32,7 +32,7 @@ import paddle
|
|||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
|
||||
|
@ -47,7 +47,7 @@ def main():
|
|||
# build model
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model)
|
||||
load_model(config, model)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
|
|
|
@ -34,7 +34,7 @@ import paddle
|
|||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
|
||||
|
@ -59,7 +59,7 @@ def main():
|
|||
# build model
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
_ = load_dygraph_params(config, model, logger, None)
|
||||
load_model(config, model)
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'])
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ import paddle
|
|||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
|
||||
|
@ -68,7 +68,7 @@ def main():
|
|||
# build model
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model)
|
||||
load_model(config, model)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
|
|
|
@ -33,7 +33,7 @@ import paddle
|
|||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
|
||||
|
@ -58,7 +58,7 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model)
|
||||
load_model(config, model)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
|
@ -75,9 +75,7 @@ def main():
|
|||
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||
]
|
||||
elif config['Architecture']['algorithm'] == "SAR":
|
||||
op[op_name]['keep_keys'] = [
|
||||
'image', 'valid_ratio'
|
||||
]
|
||||
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
||||
else:
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
transforms.append(op)
|
||||
|
|
|
@ -34,11 +34,12 @@ from paddle.jit import to_static
|
|||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
import cv2
|
||||
|
||||
|
||||
def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
|
@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model, logger)
|
||||
load_model(config, model)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
|
@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
|
|||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main(config, device, logger, vdl_writer)
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ from ppocr.losses import build_loss
|
|||
from ppocr.optimizer import build_optimizer
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.save_load import load_model
|
||||
import tools.program as program
|
||||
|
||||
dist.get_world_size()
|
||||
|
@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
|
|||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||
pre_best_model_dict = load_model(config, model, optimizer)
|
||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||
if valid_dataloader is not None:
|
||||
logger.info('valid dataloader has {} iters'.format(
|
||||
|
|
Loading…
Reference in New Issue