From 721ac0bf618c794e90b8b026f0092336b190161e Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 23 Dec 2021 08:57:46 +0000 Subject: [PATCH] refactor: simplify code 1. remove WrapLayer and wrap_theseus; 2. support call update_res() one more; 3. optim parse_pattern_str() to return list of layer parsed. --- ppcls/arch/backbone/base/theseus_layer.py | 196 ++++++++-------------- 1 file changed, 72 insertions(+), 124 deletions(-) diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index fb06c183b..c11ac6b36 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -23,6 +23,7 @@ class TheseusLayer(nn.Layer): res_dict = {"output": output} # 'list' is needed to avoid error raised by popping self.res_dict for res_key in list(self.res_dict): + # clear the res_dict because the forward process may change according to input res_dict[res_key] = self.res_dict.pop(res_key) return res_dict @@ -30,7 +31,7 @@ class TheseusLayer(nn.Layer): self.res_dict[self.res_name] = output def replace_sub(self, *args, **kwargs) -> None: - msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead." + msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead." logger.error(DeprecationWarning(msg)) raise DeprecationWarning(msg) @@ -43,20 +44,20 @@ class TheseusLayer(nn.Layer): Args: layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'. - handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. + handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed. Returns: - Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'. - + Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'. + Examples: from paddle import nn import paddleclas - def rep_func(sub_layer: nn.Layer, pattern: str): + def rep_func(layer: nn.Layer, pattern: str): new_layer = nn.Conv2D( - in_channels=sub_layer._in_channels, - out_channels=sub_layer._out_channels, + in_channels=layer._in_channels, + out_channels=layer._out_channels, kernel_size=5, padding=2 ) @@ -73,25 +74,16 @@ class TheseusLayer(nn.Layer): handle_res_dict = {} for pattern in layer_name_pattern: - # find parent layer of sub-layer specified by pattern - sub_layer_parent = None - for target_layer_dict in parse_pattern_str( - pattern=pattern, idx=(0, -1), parent_layer=self): - sub_layer_parent = target_layer_dict["target_layer"] - - if not sub_layer_parent: + # parse pattern to find target layer and its parent + layer_list = parse_pattern_str(pattern=pattern, parent_layer=self) + if not layer_list: continue + sub_layer_parent = layer_list[-2]["layer"] if len( + layer_list) > 1 else self - # find sub-layer specified by pattern - sub_layer = None - for target_layer_dict in parse_pattern_str( - pattern=pattern, idx=-1, parent_layer=sub_layer_parent): - sub_layer = target_layer_dict["target_layer"] - sub_layer_name = target_layer_dict["target_layer_name"] - sub_layer_index = target_layer_dict["target_layer_index"] - - if not sub_layer: - continue + sub_layer = layer_list[-1]["layer"] + sub_layer_name = layer_list[-1]["name"] + sub_layer_index = layer_list[-1]["index"] new_sub_layer = handle_func(sub_layer, pattern) @@ -114,65 +106,60 @@ class TheseusLayer(nn.Layer): bool: 'True' if successful, 'False' otherwise. """ - to_identity_list = [] + layer_list = parse_pattern_str(stop_layer_name, self) + if not layer_list: + return False - for target_layer_dict in parse_pattern_str(stop_layer_name, self): - sub_layer_name = target_layer_dict["target_layer_name"] - sub_layer_index = target_layer_dict["target_layer_index"] - parent_layer = target_layer_dict["parent_layer"] - to_identity_list.append( - (parent_layer, sub_layer_name, sub_layer_index)) - - for to_identity_layer in to_identity_list: - if not set_identity(*to_identity_layer): - msg = "Failed to set the layers that after stop_layer_name to IdentityLayer." + parent_layer = self + for layer_dict in layer_list: + name, index = layer_dict["name"], layer_dict["index"] + if not set_identity(parent_layer, name, index): + msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'." logger.warning(msg) return False + parent_layer = layer_dict["layer"] + return True - def update_res(self, - return_patterns: Union[str, List[str]]) -> Dict[str, bool]: - """update the results to be returned. + def update_res( + self, + return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]: + """update the result(s) to be returned. Args: return_patterns (Union[str, List[str]]): The name of layer to return output. Returns: - Dict[str, bool]: The pattern(str) is be set successfully if 'True'(bool), failed if 'False'(bool). + Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully. """ + # clear res_dict that could have been set + self.res_dict = {} + class Handler(object): def __init__(self, res_dict): + # res_dict is a reference self.res_dict = res_dict def __call__(self, layer, pattern): layer.res_dict = self.res_dict layer.res_name = pattern - layer.register_forward_post_hook(layer._save_sub_res_hook) + if hasattr(layer, "hook_remove_helper"): + layer.hook_remove_helper.remove() + layer.hook_remove_helper = layer.register_forward_post_hook( + layer._save_sub_res_hook) return layer handle_func = Handler(self.res_dict) - return self.replace_sub(return_patterns, handle_func=handle_func) + res_dict = self.layer_wrench(return_patterns, handle_func=handle_func) + if hasattr(self, "hook_remove_helper"): + self.hook_remove_helper.remove() + self.hook_remove_helper = self.register_forward_post_hook( + self._return_dict_hook) -class WrapLayer(TheseusLayer): - def __init__(self, sub_layer): - super(WrapLayer, self).__init__() - self.sub_layer = sub_layer - - def forward(self, *inputs, **kwargs): - return self.sub_layer(*inputs, **kwargs) - - -def wrap_theseus(sub_layer): - return WrapLayer(sub_layer) - - -def unwrap_theseus(sub_layer): - if isinstance(sub_layer, WrapLayer): - sub_layer = sub_layer.sub_layer - return sub_layer + return res_dict def set_identity(parent_layer: nn.Layer, @@ -211,58 +198,30 @@ def set_identity(parent_layer: nn.Layer, return stop_after -def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List: - """slice the string type "pattern" to list type by separator ".". - - Args: - pattern (str): The pattern to discribe layer name. - idx (Union[Tuple, int], optional): The index(s) of sub-list of list sliced. Defaults to None. - - Returns: - List: The sub-list of list sliced by "pattern". - """ - - pattern_list = pattern.split(".") - if idx: - if isinstance(idx, Tuple): - if len(idx) == 1: - return pattern_list[idx[0]] - elif len(idx) == 2: - return pattern_list[idx[0]:idx[1]] - else: - msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple." - logger.warning(msg) - return None - elif isinstance(idx, int): - return [pattern_list[idx]] - else: - msg = f"Only support type of 'idx' is int or Tuple." - logger.warning(msg) - return None - - return pattern_list - - -def parse_pattern_str(pattern: str, parent_layer: nn.Layer, - idx=None) -> Dict[str, Union[nn.Layer, None, str]]: +def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[ + None, List[Dict[str, Union[nn.Layer, str, None]]]]: """parse the string type pattern. Args: - pattern (str): The pattern to discribe layer name. - parent_layer (nn.Layer): The parent layer of target layer(s) specified by "pattern". - idx ([type], optional): [description]. The index(s) of sub-list of list sliced. Defaults to None. + pattern (str): The pattern to discribe layer. + parent_layer (nn.Layer): The root layer relative to the pattern. Returns: - Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer] - - Yields: - Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer] + Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order: + [ + {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist}, + {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist}, + ... + ] """ - pattern_list = slice_pattern(pattern, idx) + pattern_list = pattern.split(".") if not pattern_list: - return None, None, None + msg = f"The pattern('{pattern}') is illegal. Please check and retry." + logger.warning(msg) + return None + layer_list = [] while len(pattern_list) > 0: if '[' in pattern_list[0]: target_layer_name = pattern_list[0].split('[')[0] @@ -272,38 +231,27 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer, target_layer_index = None target_layer = getattr(parent_layer, target_layer_name, None) - target_layer = unwrap_theseus(target_layer) if target_layer is None: - msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})." + msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')." logger.warning(msg) - return { - "target_layer": None, - "target_layer_name": target_layer_name, - "target_layer_index": target_layer_index, - "parent_layer": parent_layer - } + return None if target_layer_index and target_layer: if int(target_layer_index) < 0 or int(target_layer_index) >= len( target_layer): - msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'." + msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0." logger.warning(msg) - return { - "target_layer": None, - "target_layer_name": target_layer_name, - "target_layer_index": target_layer_index, - "parent_layer": parent_layer - } - target_layer = target_layer[target_layer_index] - target_layer = unwrap_theseus(target_layer) + return None - yield { - "target_layer": target_layer, - "target_layer_name": target_layer_name, - "target_layer_index": target_layer_index, - "parent_layer": parent_layer - } + target_layer = target_layer[target_layer_index] + + layer_list.append({ + "layer": target_layer, + "name": target_layer_name, + "index": target_layer_index + }) pattern_list = pattern_list[1:] parent_layer = target_layer + return layer_list