refactor: simplify code

1. remove WrapLayer and wrap_theseus;
2. support calling update_res() more than once;
3. optimize parse_pattern_str() to return the list of parsed layers.
pull/1602/head
gaotingquan 2021-12-23 08:57:46 +00:00 committed by Tingquan Gao
parent 56911b573b
commit 721ac0bf61
1 changed files with 72 additions and 124 deletions
ppcls/arch/backbone/base

View File

@ -23,6 +23,7 @@ class TheseusLayer(nn.Layer):
res_dict = {"output": output}
# 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict):
# clear the res_dict because the forward process may change according to input
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
@ -30,7 +31,7 @@ class TheseusLayer(nn.Layer):
self.res_dict[self.res_name] = output
def replace_sub(self, *args, **kwargs) -> None:
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead."
msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)
@ -43,20 +44,20 @@ class TheseusLayer(nn.Layer):
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
Returns:
Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'.
Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
Examples:
from paddle import nn
import paddleclas
def rep_func(sub_layer: nn.Layer, pattern: str):
def rep_func(layer: nn.Layer, pattern: str):
new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels,
out_channels=sub_layer._out_channels,
in_channels=layer._in_channels,
out_channels=layer._out_channels,
kernel_size=5,
padding=2
)
@ -73,25 +74,16 @@ class TheseusLayer(nn.Layer):
handle_res_dict = {}
for pattern in layer_name_pattern:
# find parent layer of sub-layer specified by pattern
sub_layer_parent = None
for target_layer_dict in parse_pattern_str(
pattern=pattern, idx=(0, -1), parent_layer=self):
sub_layer_parent = target_layer_dict["target_layer"]
if not sub_layer_parent:
# parse pattern to find target layer and its parent
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
if not layer_list:
continue
sub_layer_parent = layer_list[-2]["layer"] if len(
layer_list) > 1 else self
# find sub-layer specified by pattern
sub_layer = None
for target_layer_dict in parse_pattern_str(
pattern=pattern, idx=-1, parent_layer=sub_layer_parent):
sub_layer = target_layer_dict["target_layer"]
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
if not sub_layer:
continue
sub_layer = layer_list[-1]["layer"]
sub_layer_name = layer_list[-1]["name"]
sub_layer_index = layer_list[-1]["index"]
new_sub_layer = handle_func(sub_layer, pattern)
@ -114,65 +106,60 @@ class TheseusLayer(nn.Layer):
bool: 'True' if successful, 'False' otherwise.
"""
to_identity_list = []
layer_list = parse_pattern_str(stop_layer_name, self)
if not layer_list:
return False
for target_layer_dict in parse_pattern_str(stop_layer_name, self):
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
parent_layer = target_layer_dict["parent_layer"]
to_identity_list.append(
(parent_layer, sub_layer_name, sub_layer_index))
for to_identity_layer in to_identity_list:
if not set_identity(*to_identity_layer):
msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
parent_layer = self
for layer_dict in layer_list:
name, index = layer_dict["name"], layer_dict["index"]
if not set_identity(parent_layer, name, index):
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
logger.warning(msg)
return False
parent_layer = layer_dict["layer"]
return True
def update_res(self,
return_patterns: Union[str, List[str]]) -> Dict[str, bool]:
"""update the results to be returned.
def update_res(
self,
return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
"""update the result(s) to be returned.
Args:
return_patterns (Union[str, List[str]]): The name of layer to return output.
Returns:
Dict[str, bool]: The pattern(str) is be set successfully if 'True'(bool), failed if 'False'(bool).
Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully.
"""
# clear res_dict that could have been set
self.res_dict = {}
class Handler(object):
def __init__(self, res_dict):
# res_dict is a reference
self.res_dict = res_dict
def __call__(self, layer, pattern):
layer.res_dict = self.res_dict
layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook)
if hasattr(layer, "hook_remove_helper"):
layer.hook_remove_helper.remove()
layer.hook_remove_helper = layer.register_forward_post_hook(
layer._save_sub_res_hook)
return layer
handle_func = Handler(self.res_dict)
return self.replace_sub(return_patterns, handle_func=handle_func)
res_dict = self.layer_wrench(return_patterns, handle_func=handle_func)
if hasattr(self, "hook_remove_helper"):
self.hook_remove_helper.remove()
self.hook_remove_helper = self.register_forward_post_hook(
self._return_dict_hook)
class WrapLayer(TheseusLayer):
    """TheseusLayer subclass that holds another layer and forwards calls to it."""

    def __init__(self, sub_layer):
        super().__init__()
        # the wrapped layer; 'forward' delegates every call to it
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        """Call the wrapped layer with the given arguments and return its result."""
        wrapped = self.sub_layer
        return wrapped(*inputs, **kwargs)
def wrap_theseus(sub_layer):
    """Wrap 'sub_layer' in a WrapLayer and return the wrapper."""
    wrapped_layer = WrapLayer(sub_layer)
    return wrapped_layer
def unwrap_theseus(sub_layer):
    """Return the layer held by a WrapLayer; any other object is returned unchanged."""
    if not isinstance(sub_layer, WrapLayer):
        # not wrapped: hand the argument back untouched
        return sub_layer
    return sub_layer.sub_layer
return res_dict
def set_identity(parent_layer: nn.Layer,
@ -211,58 +198,30 @@ def set_identity(parent_layer: nn.Layer,
return stop_after
def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List:
    """Split the string type "pattern" by separator "." and optionally select part of it.

    Args:
        pattern (str): The pattern to describe layer name.
        idx (Union[Tuple, int], optional): The index(s) selecting a sub-list of the split list.
            An int or a 1-tuple selects the single element at that index;
            a 2-tuple (start, stop) selects the slice [start:stop].
            Defaults to None, which selects the whole list.

    Returns:
        List: The selected sub-list of the list split from "pattern". None is returned
            (after a warning is logged) when 'idx' has an unsupported type or tuple length.
    """
    pattern_list = pattern.split(".")

    # Compare against None explicitly: 0 is a valid index and must not be
    # mistaken for "no index given" by a truthiness test.
    if idx is None:
        return pattern_list

    if isinstance(idx, tuple):
        if len(idx) == 1:
            # Wrap the element so every successful call returns a list,
            # matching the int branch below and the declared return type.
            return [pattern_list[idx[0]]]
        if len(idx) == 2:
            return pattern_list[idx[0]:idx[1]]
        msg = "Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple."
        logger.warning(msg)
        return None

    if isinstance(idx, int):
        return [pattern_list[idx]]

    msg = "Only support type of 'idx' is int or Tuple."
    logger.warning(msg)
    return None
def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
idx=None) -> Dict[str, Union[nn.Layer, None, str]]:
def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
None, List[Dict[str, Union[nn.Layer, str, None]]]]:
"""parse the string type pattern.
Args:
pattern (str): The pattern to discribe layer name.
parent_layer (nn.Layer): The parent layer of target layer(s) specified by "pattern".
idx ([type], optional): [description]. The index(s) of sub-list of list sliced. Defaults to None.
pattern (str): The pattern to discribe layer.
parent_layer (nn.Layer): The root layer relative to the pattern.
Returns:
Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
Yields:
Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
[
{"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
{"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
...
]
"""
pattern_list = slice_pattern(pattern, idx)
pattern_list = pattern.split(".")
if not pattern_list:
return None, None, None
msg = f"The pattern('{pattern}') is illegal. Please check and retry."
logger.warning(msg)
return None
layer_list = []
while len(pattern_list) > 0:
if '[' in pattern_list[0]:
target_layer_name = pattern_list[0].split('[')[0]
@ -272,38 +231,27 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
target_layer_index = None
target_layer = getattr(parent_layer, target_layer_name, None)
target_layer = unwrap_theseus(target_layer)
if target_layer is None:
msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})."
msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')."
logger.warning(msg)
return {
"target_layer": None,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
return None
if target_layer_index and target_layer:
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
target_layer):
msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'."
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
logger.warning(msg)
return {
"target_layer": None,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
target_layer = target_layer[target_layer_index]
target_layer = unwrap_theseus(target_layer)
return None
yield {
"target_layer": target_layer,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
target_layer = target_layer[target_layer_index]
layer_list.append({
"layer": target_layer,
"name": target_layer_name,
"index": target_layer_index
})
pattern_list = pattern_list[1:]
parent_layer = target_layer
return layer_list