refactor: simplify code
1. remove WrapLayer and wrap_theseus; 2. support calling update_res() more than once; 3. optimize parse_pattern_str() to return a list of the parsed layers.
pull/1602/head
parent
56911b573b
commit
721ac0bf61
ppcls/arch/backbone/base
|
@ -23,6 +23,7 @@ class TheseusLayer(nn.Layer):
|
|||
res_dict = {"output": output}
|
||||
# 'list' is needed to avoid error raised by popping self.res_dict
|
||||
for res_key in list(self.res_dict):
|
||||
# clear the res_dict because the forward process may change according to input
|
||||
res_dict[res_key] = self.res_dict.pop(res_key)
|
||||
return res_dict
|
||||
|
||||
|
@ -30,7 +31,7 @@ class TheseusLayer(nn.Layer):
|
|||
self.res_dict[self.res_name] = output
|
||||
|
||||
def replace_sub(self, *args, **kwargs) -> None:
|
||||
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead."
|
||||
msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead."
|
||||
logger.error(DeprecationWarning(msg))
|
||||
raise DeprecationWarning(msg)
|
||||
|
||||
|
@ -43,20 +44,20 @@ class TheseusLayer(nn.Layer):
|
|||
|
||||
Args:
|
||||
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
|
||||
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.
|
||||
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
|
||||
|
||||
Returns:
|
||||
Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'.
|
||||
|
||||
Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
|
||||
|
||||
Examples:
|
||||
|
||||
from paddle import nn
|
||||
import paddleclas
|
||||
|
||||
def rep_func(sub_layer: nn.Layer, pattern: str):
|
||||
def rep_func(layer: nn.Layer, pattern: str):
|
||||
new_layer = nn.Conv2D(
|
||||
in_channels=sub_layer._in_channels,
|
||||
out_channels=sub_layer._out_channels,
|
||||
in_channels=layer._in_channels,
|
||||
out_channels=layer._out_channels,
|
||||
kernel_size=5,
|
||||
padding=2
|
||||
)
|
||||
|
@ -73,25 +74,16 @@ class TheseusLayer(nn.Layer):
|
|||
|
||||
handle_res_dict = {}
|
||||
for pattern in layer_name_pattern:
|
||||
# find parent layer of sub-layer specified by pattern
|
||||
sub_layer_parent = None
|
||||
for target_layer_dict in parse_pattern_str(
|
||||
pattern=pattern, idx=(0, -1), parent_layer=self):
|
||||
sub_layer_parent = target_layer_dict["target_layer"]
|
||||
|
||||
if not sub_layer_parent:
|
||||
# parse pattern to find target layer and its parent
|
||||
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
|
||||
if not layer_list:
|
||||
continue
|
||||
sub_layer_parent = layer_list[-2]["layer"] if len(
|
||||
layer_list) > 1 else self
|
||||
|
||||
# find sub-layer specified by pattern
|
||||
sub_layer = None
|
||||
for target_layer_dict in parse_pattern_str(
|
||||
pattern=pattern, idx=-1, parent_layer=sub_layer_parent):
|
||||
sub_layer = target_layer_dict["target_layer"]
|
||||
sub_layer_name = target_layer_dict["target_layer_name"]
|
||||
sub_layer_index = target_layer_dict["target_layer_index"]
|
||||
|
||||
if not sub_layer:
|
||||
continue
|
||||
sub_layer = layer_list[-1]["layer"]
|
||||
sub_layer_name = layer_list[-1]["name"]
|
||||
sub_layer_index = layer_list[-1]["index"]
|
||||
|
||||
new_sub_layer = handle_func(sub_layer, pattern)
|
||||
|
||||
|
@ -114,65 +106,60 @@ class TheseusLayer(nn.Layer):
|
|||
bool: 'True' if successful, 'False' otherwise.
|
||||
"""
|
||||
|
||||
to_identity_list = []
|
||||
layer_list = parse_pattern_str(stop_layer_name, self)
|
||||
if not layer_list:
|
||||
return False
|
||||
|
||||
for target_layer_dict in parse_pattern_str(stop_layer_name, self):
|
||||
sub_layer_name = target_layer_dict["target_layer_name"]
|
||||
sub_layer_index = target_layer_dict["target_layer_index"]
|
||||
parent_layer = target_layer_dict["parent_layer"]
|
||||
to_identity_list.append(
|
||||
(parent_layer, sub_layer_name, sub_layer_index))
|
||||
|
||||
for to_identity_layer in to_identity_list:
|
||||
if not set_identity(*to_identity_layer):
|
||||
msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
|
||||
parent_layer = self
|
||||
for layer_dict in layer_list:
|
||||
name, index = layer_dict["name"], layer_dict["index"]
|
||||
if not set_identity(parent_layer, name, index):
|
||||
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
|
||||
logger.warning(msg)
|
||||
return False
|
||||
parent_layer = layer_dict["layer"]
|
||||
|
||||
return True
|
||||
|
||||
def update_res(self,
|
||||
return_patterns: Union[str, List[str]]) -> Dict[str, bool]:
|
||||
"""update the results to be returned.
|
||||
def update_res(
|
||||
self,
|
||||
return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
|
||||
"""update the result(s) to be returned.
|
||||
|
||||
Args:
|
||||
return_patterns (Union[str, List[str]]): The name of layer to return output.
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: The pattern(str) is be set successfully if 'True'(bool), failed if 'False'(bool).
|
||||
Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully.
|
||||
"""
|
||||
|
||||
# clear res_dict that could have been set
|
||||
self.res_dict = {}
|
||||
|
||||
class Handler(object):
|
||||
def __init__(self, res_dict):
|
||||
# res_dict is a reference
|
||||
self.res_dict = res_dict
|
||||
|
||||
def __call__(self, layer, pattern):
|
||||
layer.res_dict = self.res_dict
|
||||
layer.res_name = pattern
|
||||
layer.register_forward_post_hook(layer._save_sub_res_hook)
|
||||
if hasattr(layer, "hook_remove_helper"):
|
||||
layer.hook_remove_helper.remove()
|
||||
layer.hook_remove_helper = layer.register_forward_post_hook(
|
||||
layer._save_sub_res_hook)
|
||||
return layer
|
||||
|
||||
handle_func = Handler(self.res_dict)
|
||||
|
||||
return self.replace_sub(return_patterns, handle_func=handle_func)
|
||||
res_dict = self.layer_wrench(return_patterns, handle_func=handle_func)
|
||||
|
||||
if hasattr(self, "hook_remove_helper"):
|
||||
self.hook_remove_helper.remove()
|
||||
self.hook_remove_helper = self.register_forward_post_hook(
|
||||
self._return_dict_hook)
|
||||
|
||||
class WrapLayer(TheseusLayer):
|
||||
def __init__(self, sub_layer):
|
||||
super(WrapLayer, self).__init__()
|
||||
self.sub_layer = sub_layer
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
return self.sub_layer(*inputs, **kwargs)
|
||||
|
||||
|
||||
def wrap_theseus(sub_layer):
|
||||
return WrapLayer(sub_layer)
|
||||
|
||||
|
||||
def unwrap_theseus(sub_layer):
|
||||
if isinstance(sub_layer, WrapLayer):
|
||||
sub_layer = sub_layer.sub_layer
|
||||
return sub_layer
|
||||
return res_dict
|
||||
|
||||
|
||||
def set_identity(parent_layer: nn.Layer,
|
||||
|
@ -211,58 +198,30 @@ def set_identity(parent_layer: nn.Layer,
|
|||
return stop_after
|
||||
|
||||
|
||||
def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List:
|
||||
"""slice the string type "pattern" to list type by separator ".".
|
||||
|
||||
Args:
|
||||
pattern (str): The pattern to discribe layer name.
|
||||
idx (Union[Tuple, int], optional): The index(s) of sub-list of list sliced. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List: The sub-list of list sliced by "pattern".
|
||||
"""
|
||||
|
||||
pattern_list = pattern.split(".")
|
||||
if idx:
|
||||
if isinstance(idx, Tuple):
|
||||
if len(idx) == 1:
|
||||
return pattern_list[idx[0]]
|
||||
elif len(idx) == 2:
|
||||
return pattern_list[idx[0]:idx[1]]
|
||||
else:
|
||||
msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
elif isinstance(idx, int):
|
||||
return [pattern_list[idx]]
|
||||
else:
|
||||
msg = f"Only support type of 'idx' is int or Tuple."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
|
||||
return pattern_list
|
||||
|
||||
|
||||
def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
|
||||
idx=None) -> Dict[str, Union[nn.Layer, None, str]]:
|
||||
def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
|
||||
None, List[Dict[str, Union[nn.Layer, str, None]]]]:
|
||||
"""parse the string type pattern.
|
||||
|
||||
Args:
|
||||
pattern (str): The pattern to discribe layer name.
|
||||
parent_layer (nn.Layer): The parent layer of target layer(s) specified by "pattern".
|
||||
idx ([type], optional): [description]. The index(s) of sub-list of list sliced. Defaults to None.
|
||||
pattern (str): The pattern to discribe layer.
|
||||
parent_layer (nn.Layer): The root layer relative to the pattern.
|
||||
|
||||
Returns:
|
||||
Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
|
||||
|
||||
Yields:
|
||||
Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
|
||||
Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
|
||||
[
|
||||
{"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
|
||||
{"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
pattern_list = slice_pattern(pattern, idx)
|
||||
pattern_list = pattern.split(".")
|
||||
if not pattern_list:
|
||||
return None, None, None
|
||||
msg = f"The pattern('{pattern}') is illegal. Please check and retry."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
|
||||
layer_list = []
|
||||
while len(pattern_list) > 0:
|
||||
if '[' in pattern_list[0]:
|
||||
target_layer_name = pattern_list[0].split('[')[0]
|
||||
|
@ -272,38 +231,27 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
|
|||
target_layer_index = None
|
||||
|
||||
target_layer = getattr(parent_layer, target_layer_name, None)
|
||||
target_layer = unwrap_theseus(target_layer)
|
||||
|
||||
if target_layer is None:
|
||||
msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})."
|
||||
msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')."
|
||||
logger.warning(msg)
|
||||
return {
|
||||
"target_layer": None,
|
||||
"target_layer_name": target_layer_name,
|
||||
"target_layer_index": target_layer_index,
|
||||
"parent_layer": parent_layer
|
||||
}
|
||||
return None
|
||||
|
||||
if target_layer_index and target_layer:
|
||||
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
|
||||
target_layer):
|
||||
msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'."
|
||||
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
|
||||
logger.warning(msg)
|
||||
return {
|
||||
"target_layer": None,
|
||||
"target_layer_name": target_layer_name,
|
||||
"target_layer_index": target_layer_index,
|
||||
"parent_layer": parent_layer
|
||||
}
|
||||
target_layer = target_layer[target_layer_index]
|
||||
target_layer = unwrap_theseus(target_layer)
|
||||
return None
|
||||
|
||||
yield {
|
||||
"target_layer": target_layer,
|
||||
"target_layer_name": target_layer_name,
|
||||
"target_layer_index": target_layer_index,
|
||||
"parent_layer": parent_layer
|
||||
}
|
||||
target_layer = target_layer[target_layer_index]
|
||||
|
||||
layer_list.append({
|
||||
"layer": target_layer,
|
||||
"name": target_layer_name,
|
||||
"index": target_layer_index
|
||||
})
|
||||
|
||||
pattern_list = pattern_list[1:]
|
||||
parent_layer = target_layer
|
||||
return layer_list
|
||||
|
|
Loading…
Reference in New Issue