From ae4167dc32fe4e35dffc90b48a80ee07cd1dc00f Mon Sep 17 00:00:00 2001
From: zhoujun <zjwenmu@gmail.com>
Date: Fri, 12 Nov 2021 11:06:36 +0800
Subject: [PATCH] merge init_model and load_dygraph_params to load_model
 (#4623)

* merge init_model and load_dygraph_params to load_model
---
 deploy/slim/prune/export_prune_model.py       |  4 +-
 deploy/slim/prune/sensitivity_anal.py         |  4 +-
 deploy/slim/quantization/export_model.py      |  4 +-
 deploy/slim/quantization/quant.py             |  4 +-
 deploy/slim/quantization/quant_kl.py          |  2 +-
 .../architectures/distillation_model.py       |  2 +-
 ppocr/utils/save_load.py                      | 74 +++++--------------
 tools/eval.py                                 |  4 +-
 tools/export_center.py                        |  4 +-
 tools/export_model.py                         |  4 +-
 tools/infer_cls.py                            |  4 +-
 tools/infer_det.py                            |  4 +-
 tools/infer_e2e.py                            |  4 +-
 tools/infer_rec.py                            |  8 +-
 tools/infer_table.py                          |  6 +-
 tools/train.py                                |  4 +-
 16 files changed, 47 insertions(+), 89 deletions(-)

diff --git a/deploy/slim/prune/export_prune_model.py b/deploy/slim/prune/export_prune_model.py
index 29f7d211d..2c9d0a183 100644
--- a/deploy/slim/prune/export_prune_model.py
+++ b/deploy/slim/prune/export_prune_model.py
@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
 
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 import tools.program as program
 
 
@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
     logger.info(f"FLOPs after pruning: {flops}")
 
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, logger, None)
+    load_model(config, model)
     metric = program.eval(model, valid_dataloader, post_process_class,
                           eval_class)
     logger.info(f"metric['hmean']: {metric['hmean']}")
diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py
index 0f0492af2..c5d008779 100644
--- a/deploy/slim/prune/sensitivity_anal.py
+++ b/deploy/slim/prune/sensitivity_anal.py
@@ -32,7 +32,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 import tools.program as program
 
 dist.get_world_size()
@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, logger, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer)
 
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))
diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py
index d94e53034..dddae923d 100755
--- a/deploy/slim/quantization/export_model.py
+++ b/deploy/slim/quantization/export_model.py
@@ -28,7 +28,7 @@ from paddle.jit import to_static
 
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
 from ppocr.metrics import build_metric
@@ -101,7 +101,7 @@ def main():
     quanter = QAT(config=quant_config)
     quanter.quantize(model)
 
-    init_model(config, model)
+    load_model(config, model)
     model.eval()
 
     # build metric
diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py
index 37aab68a0..941cfb36b 100755
--- a/deploy/slim/quantization/quant.py
+++ b/deploy/slim/quantization/quant.py
@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 import tools.program as program
 from paddleslim.dygraph.quant import QAT
 
@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, logger, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer)
 
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))
diff --git a/deploy/slim/quantization/quant_kl.py b/deploy/slim/quantization/quant_kl.py
index d866784ae..cc3a455b9 100755
--- a/deploy/slim/quantization/quant_kl.py
+++ b/deploy/slim/quantization/quant_kl.py
@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 import tools.program as program
 import paddleslim
 from paddleslim.dygraph.quant import QAT
diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py
index 1e95fe574..5e867940e 100644
--- a/ppocr/modeling/architectures/distillation_model.py
+++ b/ppocr/modeling/architectures/distillation_model.py
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
 from ppocr.modeling.necks import build_neck
 from ppocr.modeling.heads import build_head
 from .base_model import BaseModel
-from ppocr.utils.save_load import init_model, load_pretrained_params
+from ppocr.utils.save_load import load_pretrained_params
 
 __all__ = ['DistillationModel']
 
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index a7d24dd71..702f3e977 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -25,7 +25,7 @@ import paddle
 
 from ppocr.utils.logging import get_logger
 
-__all__ = ['init_model', 'save_model', 'load_dygraph_params']
+__all__ = ['load_model']
 
 
 def _mkdir_if_not_exist(path, logger):
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
                 raise OSError('Failed to mkdir {}'.format(path))
 
 
-def init_model(config, model, optimizer=None, lr_scheduler=None):
+def load_model(config, model, optimizer=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
     pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
     if checkpoints:
-        assert os.path.exists(checkpoints + ".pdparams"), \
-            "Given dir {}.pdparams not exist.".format(checkpoints)
+        if checkpoints.endswith('pdparams'):
+            checkpoints = checkpoints.replace('.pdparams', '')
         assert os.path.exists(checkpoints + ".pdopt"), \
-            "Given dir {}.pdopt not exist.".format(checkpoints)
-        para_dict = paddle.load(checkpoints + '.pdparams')
-        opti_dict = paddle.load(checkpoints + '.pdopt')
-        model.set_state_dict(para_dict)
+            f"The {checkpoints}.pdopt does not exists!"
+        load_pretrained_params(model, checkpoints)
+        optim_dict = paddle.load(checkpoints + '.pdopt')
         if optimizer is not None:
-            optimizer.set_state_dict(opti_dict)
+            optimizer.set_state_dict(optim_dict)
 
         if os.path.exists(checkpoints + '.states'):
             with open(checkpoints + '.states', 'rb') as f:
@@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
                 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        if not isinstance(pretrained_model, list):
-            pretrained_model = [pretrained_model]
-        for pretrained in pretrained_model:
-            if not (os.path.isdir(pretrained) or
-                    os.path.exists(pretrained + '.pdparams')):
-                raise ValueError("Model pretrain path {} does not "
-                                 "exists.".format(pretrained))
-            param_state_dict = paddle.load(pretrained + '.pdparams')
-            model.set_state_dict(param_state_dict)
-            logger.info("load pretrained model from {}".format(
-                pretrained_model))
+        load_pretrained_params(model, pretrained_model)
     else:
         logger.info('train from scratch')
     return best_model_dict
 
 
-def load_dygraph_params(config, model, logger, optimizer):
-    ckp = config['Global']['checkpoints']
-    if ckp and os.path.exists(ckp + ".pdparams"):
-        pre_best_model_dict = init_model(config, model, optimizer)
-        return pre_best_model_dict
-    else:
-        pm = config['Global']['pretrained_model']
-        if pm is None:
-            return {}
-        if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
-            logger.info(f"The pretrained_model {pm} does not exists!")
-            return {}
-        pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
-        params = paddle.load(pm)
-        state_dict = model.state_dict()
-        new_state_dict = {}
-        for k1, k2 in zip(state_dict.keys(), params.keys()):
-            if list(state_dict[k1].shape) == list(params[k2].shape):
-                new_state_dict[k1] = params[k2]
-            else:
-                logger.info(
-                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
-                )
-        model.set_state_dict(new_state_dict)
-        logger.info(f"loaded pretrained_model successful from {pm}")
-        return {}
-
-
 def load_pretrained_params(model, path):
-    if path is None:
-        return False
-    if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
-        print(f"The pretrained_model {path} does not exists!")
-        return False
+    logger = get_logger()
+    if path.endswith('pdparams'):
+        path = path.replace('.pdparams', '')
+    assert os.path.exists(path + ".pdparams"), \
+        f"The {path}.pdparams does not exists!"
 
-    path = path if path.endswith('.pdparams') else path + '.pdparams'
-    params = paddle.load(path)
+    params = paddle.load(path + '.pdparams')
     state_dict = model.state_dict()
     new_state_dict = {}
     for k1, k2 in zip(state_dict.keys(), params.keys()):
         if list(state_dict[k1].shape) == list(params[k2].shape):
             new_state_dict[k1] = params[k2]
         else:
-            print(
+            logger.info(
                 f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
             )
     model.set_state_dict(new_state_dict)
-    print(f"load pretrain successful from {path}")
+    logger.info(f"load pretrain successful from {path}")
     return model
 
 
diff --git a/tools/eval.py b/tools/eval.py
index 28247bc57..c85490a31 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import print_dict
 import tools.program as program
 
@@ -60,7 +60,7 @@ def main():
     else:
         model_type = None
 
-    best_model_dict = load_dygraph_params(config, model, logger, None)
+    best_model_dict = load_model(config, model)
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():
diff --git a/tools/export_center.py b/tools/export_center.py
index c46e8b9d5..30b9c3349 100644
--- a/tools/export_center.py
+++ b/tools/export_center.py
@@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import print_dict
 import tools.program as program
 
@@ -57,7 +57,7 @@ def main():
 
     model = build_model(config['Architecture'])
 
-    best_model_dict = load_dygraph_params(config, model, logger, None)
+    best_model_dict = load_model(config, model)
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():
diff --git a/tools/export_model.py b/tools/export_model.py
index 64a0d4036..9ed8e1b6a 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -26,7 +26,7 @@ from paddle.jit import to_static
 
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
 
@@ -107,7 +107,7 @@ def main():
         else:  # base rec model
             config["Architecture"]["Head"]["out_channels"] = char_num
     model = build_model(config["Architecture"])
-    init_model(config, model)
+    load_model(config, model)
     model.eval()
 
     save_path = config["Global"]["save_inference_dir"]
diff --git a/tools/infer_cls.py b/tools/infer_cls.py
index a588cab43..7522e4390 100755
--- a/tools/infer_cls.py
+++ b/tools/infer_cls.py
@@ -32,7 +32,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -47,7 +47,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model)
+    load_model(config, model)
 
     # create data ops
     transforms = []
diff --git a/tools/infer_det.py b/tools/infer_det.py
index ce16da8dc..bb2cca736 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -34,7 +34,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -59,7 +59,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    _ = load_dygraph_params(config, model, logger, None)
+    load_model(config, model)
     # build post process
     post_process_class = build_post_process(config['PostProcess'])
 
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
index 1cd468b8e..96dbac8e8 100755
--- a/tools/infer_e2e.py
+++ b/tools/infer_e2e.py
@@ -34,7 +34,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -68,7 +68,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model)
+    load_model(config, model)
 
     # build post process
     post_process_class = build_post_process(config['PostProcess'],
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 29d4b530d..adc3c1c3c 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -33,7 +33,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -58,7 +58,7 @@ def main():
 
     model = build_model(config['Architecture'])
 
-    init_model(config, model)
+    load_model(config, model)
 
     # create data ops
     transforms = []
@@ -75,9 +75,7 @@ def main():
                     'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
                 ]
             elif config['Architecture']['algorithm'] == "SAR":
-                op[op_name]['keep_keys'] = [
-                    'image', 'valid_ratio'
-                ]
+                op[op_name]['keep_keys'] = ['image', 'valid_ratio']
             else:
                 op[op_name]['keep_keys'] = ['image']
         transforms.append(op)
diff --git a/tools/infer_table.py b/tools/infer_table.py
index f743d8754..c73e38404 100644
--- a/tools/infer_table.py
+++ b/tools/infer_table.py
@@ -34,11 +34,12 @@ from paddle.jit import to_static
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 import cv2
 
+
 def main(config, device, logger, vdl_writer):
     global_config = config['Global']
 
@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
 
     model = build_model(config['Architecture'])
 
-    init_model(config, model, logger)
+    load_model(config, model)
 
     # create data ops
     transforms = []
@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
 if __name__ == '__main__':
     config, device, logger, vdl_writer = program.preprocess()
     main(config, device, logger, vdl_writer)
-
diff --git a/tools/train.py b/tools/train.py
index d182af298..f3852469e 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.save_load import load_model
 import tools.program as program
 
 dist.get_world_size()
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer)
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(