From 80b8ca3f23089dd88dd6b35b8309e3a459c72f1a Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 20 Apr 2022 12:16:40 +0800
Subject: [PATCH] fix ppcls/optimizer/__init__.py

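Default each optimizer's scope to "all" instead of the Arch name, skip
models without trainable parameters when resolving a scope, and match a
named scope against sublayer class names, so every optimizer config still
resolves to exactly one model in model_list.
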
---
 ppcls/optimizer/__init__.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index cc0041137..b7b4d4210 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -51,7 +51,7 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         optim_name = optim_config.pop("name")
         optim_config: List[Dict[str, Dict]] = [{
             optim_name: {
-                'scope': config["Arch"].get("name"),
+                'scope': "all",
                 **
                 optim_config
             }
@@ -59,10 +59,10 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     optim_list = []
     lr_list = []
     for optim_item in optim_config:
-        # optim_cfg = {optim_name1: {scope: xxx, **optim_cfg}}
+        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
         # step1 build lr
-        optim_name = list(optim_item.keys())[0]  # get optim_name1
-        optim_scope = optim_item[optim_name].pop('scope')  # get scope
+        optim_name = list(optim_item.keys())[0]  # get optim_name
+        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
         optim_cfg = optim_item[optim_name]  # get optim_cfg
 
         lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
@@ -78,7 +78,8 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             reg_name = reg_config.pop('name') + 'Decay'
             reg = getattr(paddle.regularizer, reg_name)(**reg_config)
             optim_cfg["weight_decay"] = reg
-            logger.debug("build regularizer ({}) success..".format(reg))
+            logger.debug("build regularizer ({}) for scope ({}) success..".
+                         format(reg, optim_scope))
         # step3 build optimizer
         if 'clip_norm' in optim_cfg:
             clip_norm = optim_cfg.pop('clip_norm')
@@ -87,11 +88,16 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             grad_clip = None
         optim_model = []
         for i in range(len(model_list)):
-            class_name = model_list[i].__class__.__name__
-            if class_name == optim_scope:
+            if len(model_list[i].parameters()) == 0:
+                continue
+            if optim_scope == "all":
                 optim_model.append(model_list[i])
-        assert len(optim_model) == 1 and len(optim_model[0].parameters()) > 0, \
-            f"Invalid optim model for optim scope({optim_scope}), number of optim_model={len(optim_model)}, and number of optim_model's params={len(optim_model[0].parameters())}"
+            else:
+                if any(m.__class__.__name__ == optim_scope
+                       for m in model_list[i].sublayers(True)):
+                    optim_model.append(model_list[i])
+        assert len(optim_model) == 1, \
+            "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
         optim = getattr(optimizer, optim_name)(
             learning_rate=lr, grad_clip=grad_clip,
             **optim_cfg)(model_list=optim_model)
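
For reference, a minimal sketch of the new scope-resolution behaviour, outside
the patch itself. The helper name select_optim_model and the toy model_list
below are hypothetical, assuming paddle.nn layers; only the selection logic
mirrors the patched code.

    from paddle import nn


    def select_optim_model(model_list, optim_scope):
        """Keep models that have trainable parameters and either match
        scope "all" or contain a sublayer whose class name equals
        optim_scope; exactly one model must remain."""
        optim_model = []
        for model in model_list:
            if len(model.parameters()) == 0:
                continue
            if optim_scope == "all":
                optim_model.append(model)
            elif any(m.__class__.__name__ == optim_scope
                     for m in model.sublayers(include_self=True)):
                optim_model.append(model)
        assert len(optim_model) == 1, \
            "Invalid optim model for optim scope({}), number of optim_model={}".format(
                optim_scope, len(optim_model))
        return optim_model


    # hypothetical usage: one backbone with parameters, one parameter-free layer
    backbone = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    loss = nn.Layer()  # no parameters, skipped by the filter
    print([m.__class__.__name__ for m in select_optim_model([backbone, loss], "all")])
    print([m.__class__.__name__ for m in select_optim_model([backbone, loss], "Linear")])

Because the default scope is now "all" rather than the Arch name, a config
with a single optimizer needs no scope entry at all, while multi-optimizer
configs can still target a specific submodule by its class name.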