From 32dc1c1c0d37bba769d4d706215a993204a445bc Mon Sep 17 00:00:00 2001
From: littletomatodonkey <dazhiningsibuqu@163.com>
Date: Wed, 2 Sep 2020 15:15:49 +0000
Subject: [PATCH] improve dygraph model

---
 ppcls/utils/check.py  |  4 +-
 ppcls/utils/config.py | 14 +++++--
 tools/program.py      | 93 ++++++++++++-------------------------------
 tools/run.sh          |  3 +-
 4 files changed, 40 insertions(+), 74 deletions(-)

diff --git a/ppcls/utils/check.py b/ppcls/utils/check.py
index b09a2498c..4716c44dd 100644
--- a/ppcls/utils/check.py
+++ b/ppcls/utils/check.py
@@ -31,12 +31,12 @@ def check_version():
     Log error and exit when the installed version of paddlepaddle is
     not satisfied.
     """
-    err = "PaddlePaddle version 2.0.0 or higher is required, " \
+    err = "PaddlePaddle version 1.8.0 or higher is required, " \
           "or a suitable develop version is satisfied as well. \n" \
           "Please make sure the version is good with your code." \
 
     try:
-        fluid.require_version('2.0.0')
+        fluid.require_version('1.8.0')
     except Exception:
         logger.error(err)
         sys.exit(1)
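
The relaxed gate above can be exercised on its own. A minimal sketch, assuming a
PaddlePaddle 1.8-era install where `paddle.fluid.require_version` is available
(it raises on an incompatible version, which is why check_version() wraps it in
a broad try/except; per the original error text, develop builds are accepted):

    import sys

    import paddle.fluid as fluid

    def check_paddle_version(min_version="1.8.0"):
        # require_version raises if the installed PaddlePaddle is older than
        # min_version, mirroring the behavior of check_version() above.
        try:
            fluid.require_version(min_version)
        except Exception:
            sys.stderr.write(
                "PaddlePaddle {} or higher is required, or a suitable "
                "develop version.\n".format(min_version))
            sys.exit(1)

    check_paddle_version()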
diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py
index 93b11569e..25a918afe 100644
--- a/ppcls/utils/config.py
+++ b/ppcls/utils/config.py
@@ -64,14 +64,18 @@ def print_dict(d, delimiter=0):
     placeholder = "-" * 60
     for k, v in sorted(d.items()):
         if isinstance(v, dict):
-            logger.info("{}{} : ".format(delimiter * " ", logger.coloring(k, "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ",
+                                         logger.coloring(k, "HEADER")))
             print_dict(v, delimiter + 4)
         elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
-            logger.info("{}{} : ".format(delimiter * " ", logger.coloring(str(k),"HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ",
+                                         logger.coloring(str(k), "HEADER")))
             for value in v:
                 print_dict(value, delimiter + 4)
         else:
-            logger.info("{}{} : {}".format(delimiter * " ", logger.coloring(k,"HEADER"), logger.coloring(v,"OKGREEN")))
+            logger.info("{}{} : {}".format(delimiter * " ",
+                                           logger.coloring(k, "HEADER"),
+                                           logger.coloring(v, "OKGREEN")))
 
         if k.isupper():
             logger.info(placeholder)
@@ -138,7 +142,9 @@ def override(dl, ks, v):
             override(dl[k], ks[1:], v)
     else:
         if len(ks) == 1:
-            assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
+            # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
+            if ks[0] not in dl:
+                logger.warning('A new field ({}) detected!'.format(ks[0]))
             dl[ks[0]] = str2num(v)
         else:
             override(dl[ks[0]], ks[1:], v)
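
For context, the new override semantics reduce to roughly the following. This
is a simplified, dict-only sketch (the real override() also handles
list-valued nodes and converts values via str2num), with illustrative names:

    import warnings

    def override(cfg, keys, value):
        # Dotted keys recurse one component at a time; an unknown leaf key
        # now triggers a warning and is added, instead of failing an assert.
        if len(keys) == 1:
            if keys[0] not in cfg:
                warnings.warn("A new field ({}) detected!".format(keys[0]))
            cfg[keys[0]] = value
        else:
            override(cfg[keys[0]], keys[1:], value)

    cfg = {"TRAIN": {"batch_size": 32}}
    override(cfg, "TRAIN.batch_size".split("."), 64)  # existing key: silent
    override(cfg, ["print_interval"], 10)             # new key: warns, added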
diff --git a/tools/program.py b/tools/program.py
index fe69fd1b2..31d4d85c3 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -35,8 +35,6 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 
 from paddle.fluid.dygraph.base import to_variable
-from paddle.fluid.incubate.fleet.collective import fleet
-from paddle.fluid.incubate.fleet.collective import DistributedStrategy
 
 
 def create_dataloader():
@@ -243,43 +241,6 @@ def create_optimizer(config, parameter_list=None):
     return opt(lr, parameter_list)
 
 
-def dist_optimizer(config, optimizer):
-    """
-    Create a distributed optimizer based on a normal optimizer
-
-    Args:
-        config(dict):
-        optimizer(): a normal optimizer
-
-    Returns:
-        optimizer: a distributed optimizer
-    """
-    exec_strategy = fluid.ExecutionStrategy()
-    exec_strategy.num_threads = 3
-    exec_strategy.num_iteration_per_drop_scope = 10
-
-    dist_strategy = DistributedStrategy()
-    dist_strategy.nccl_comm_num = 1
-    dist_strategy.fuse_all_reduce_ops = True
-    dist_strategy.exec_strategy = exec_strategy
-    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
-
-    return optimizer
-
-
-def mixed_precision_optimizer(config, optimizer):
-    use_fp16 = config.get('use_fp16', False)
-    amp_scale_loss = config.get('amp_scale_loss', 1.0)
-    use_dynamic_loss_scaling = config.get('use_dynamic_loss_scaling', False)
-    if use_fp16:
-        optimizer = fluid.contrib.mixed_precision.decorate(
-            optimizer,
-            init_loss_scaling=amp_scale_loss,
-            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
-
-    return optimizer
-
-
 def create_feeds(batch, use_mix):
     image = batch[0]
     if use_mix:
@@ -307,26 +268,22 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
 
     Returns:
     """
+    print_interval = config.get("print_interval", 10)
     use_mix = config.get("use_mix", False) and mode == "train"
-    if use_mix:
-        metric_list = OrderedDict([
-            ("loss", AverageMeter('loss', '7.4f')),
-            ("lr", AverageMeter(
-                'lr', 'f', need_avg=False)),
-            ("batch_time", AverageMeter('elapse', '.3f')),
-            ('reader_time', AverageMeter('reader', '.3f')),
-        ])
-    else:
+
+    metric_list = [
+        ("loss", AverageMeter('loss', '7.4f')),
+        ("lr", AverageMeter(
+            'lr', 'f', need_avg=False)),
+        ("batch_time", AverageMeter('elapse', '.3f')),
+        ('reader_time', AverageMeter('reader', '.3f')),
+    ]
+    if not use_mix:
         topk_name = 'top{}'.format(config.topk)
-        metric_list = OrderedDict([
-            ("loss", AverageMeter('loss', '7.4f')),
-            ("top1", AverageMeter('top1', '.4f')),
-            (topk_name, AverageMeter(topk_name, '.4f')),
-            ("lr", AverageMeter(
-                'lr', 'f', need_avg=False)),
-            ("batch_time", AverageMeter('elapse', '.3f')),
-            ('reader_time', AverageMeter('reader', '.3f')),
-        ])
+        metric_list.insert(1, (topk_name, AverageMeter(topk_name, '.4f')))
+        metric_list.insert(1, ("top1", AverageMeter("top1", '.4f')))
+
+    metric_list = OrderedDict(metric_list)
 
     tic = time.time()
     for idx, batch in enumerate(dataloader()):
@@ -354,17 +311,19 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
         tic = time.time()
 
         fetchs_str = ' '.join([str(m.value) for m in metric_list.values()])
-        if mode == 'eval':
-            logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
-        else:
-            epoch_str = "epoch:{:<3d}".format(epoch)
-            step_str = "{:s} step:{:<4d}".format(mode, idx)
 
-            logger.info("{:s} {:s} {:s}s".format(
-                logger.coloring(epoch_str, "HEADER")
-                if idx == 0 else epoch_str,
-                logger.coloring(step_str, "PURPLE"),
-                logger.coloring(fetchs_str, 'OKGREEN')))
+        if idx % print_interval == 0:
+            if mode == 'eval':
+                logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx,
+                                                            fetchs_str))
+            else:
+                epoch_str = "epoch:{:<3d}".format(epoch)
+                step_str = "{:s} step:{:<4d}".format(mode, idx)
+                logger.info("{:s} {:s} {:s}s".format(
+                    logger.coloring(epoch_str, "HEADER")
+                    if idx == 0 else epoch_str,
+                    logger.coloring(step_str, "PURPLE"),
+                    logger.coloring(fetchs_str, 'OKGREEN')))
 
     end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                        [metric_list['batch_time'].total])
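
The net effect of the run() changes: mix and non-mix modes now share one
metric list (with top1/topk spliced in for the non-mix case), metrics are
still updated every iteration, but log lines are only emitted every
print_interval steps (step 0 included, since 0 % N == 0). A standalone
sketch of that cadence, with a dummy loop standing in for the dataloader:

    print_interval = 10

    for idx in range(35):
        # ... forward/backward pass and metric_list updates happen here ...
        if idx % print_interval == 0:
            print("train step:{:<4d} (logged)".format(idx))

    # Prints at steps 0, 10, 20, and 30 only.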
diff --git a/tools/run.sh b/tools/run.sh
index 55f2918d9..ce9772eb8 100755
--- a/tools/run.sh
+++ b/tools/run.sh
@@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
 python -m paddle.distributed.launch \
     --selected_gpus="0,1,2,3" \
     tools/train.py \
-        -c ./configs/ResNet/ResNet50.yaml
+        -c ./configs/ResNet/ResNet50_vd.yaml \
+        -o print_interval=10
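
The extra -o flag goes through the override() path above: assuming
print_interval is not already defined in ResNet50_vd.yaml, it becomes a new
top-level config key (emitting the "new field" warning) and is then read back
in run() with a matching default. A hypothetical check of that round trip,
reusing the simplified override() sketch from earlier:

    cfg = {}                                 # top-level config after YAML load
    override(cfg, ["print_interval"], 10)    # effect of -o print_interval=10
    assert cfg.get("print_interval", 10) == 10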