refactor: strengthen parse_pattern_str() func

pull/1602/head
gaotingquan 2021-12-21 13:51:45 +00:00 committed by Tingquan Gao
parent 8d0b0d4b0a
commit 0f126b75da
1 changed files with 119 additions and 71 deletions

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Union, Callable, Any from typing import Tuple, List, Dict, Union, Callable, Any
from paddle import nn from paddle import nn
from ppcls.utils import logger from ppcls.utils import logger
@ -61,23 +61,28 @@ class TheseusLayer(nn.Layer):
print(res) print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True} # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
""" """
if not isinstance(layer_name_pattern, list): if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern] layer_name_pattern = [layer_name_pattern]
handle_res_dict = {} handle_res_dict = {}
for pattern in layer_name_pattern: for pattern in layer_name_pattern:
# pattern_list = pattern.split(".")
# find parent layer of sub-layer specified by pattern # find parent layer of sub-layer specified by pattern
sub_layer_parent, _, _ = parse_pattern_str( sub_layer_parent = None
pattern=pattern, idx=(0, -1), sub_layer_parent=self) for target_layer_dict in parse_pattern_str(
pattern=pattern, idx=(0, -1), parent_layer=self):
sub_layer_parent = target_layer_dict["target_layer"]
if not sub_layer_parent: if not sub_layer_parent:
continue continue
# find sub-layer specified by pattern # find sub-layer specified by pattern
sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str( sub_layer = None
pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent) for target_layer_dict in parse_pattern_str(
pattern=pattern, idx=-1, parent_layer=sub_layer_parent):
sub_layer = target_layer_dict["target_layer"]
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
if not sub_layer: if not sub_layer:
continue continue
@ -93,26 +98,6 @@ class TheseusLayer(nn.Layer):
handle_res_dict[pattern] = new_sub_layer handle_res_dict[pattern] = new_sub_layer
return handle_res_dict return handle_res_dict
def _set_identity(self, layer, layer_name, layer_index=None):
stop_after = False
for sub_layer_name in layer._sub_layers:
if stop_after:
layer._sub_layers[sub_layer_name] = Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
if layer_index and stop_after:
stop_after = False
for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
if stop_after:
layer._sub_layers[layer_name][sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
@ -122,32 +107,18 @@ class TheseusLayer(nn.Layer):
Returns: Returns:
bool: 'True' if successful, 'False' otherwise. bool: 'True' if successful, 'False' otherwise.
""" """
pattern_list = stop_layer_name.split(".")
to_identity_list = [] to_identity_list = []
# TODO(gaotingquan): replace code by self._parse_pattern_str() for target_layer_dict in parse_pattern_str(stop_layer_name, self):
layer = self sub_layer_name = target_layer_dict["target_layer_name"]
while len(pattern_list) > 0: sub_layer_index = target_layer_dict["target_layer_index"]
layer_parent = layer parent_layer = target_layer_dict["parent_layer"]
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
layer = getattr(layer, sub_layer_name)[sub_layer_index]
else:
sub_layer_name = pattern_list[0]
sub_layer_index = None
layer = getattr(layer, sub_layer_name, None)
if layer is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})."
logger.warning(msg)
return False
to_identity_list.append( to_identity_list.append(
(layer_parent, sub_layer_name, sub_layer_index)) (parent_layer, sub_layer_name, sub_layer_index))
pattern_list = pattern_list[1:]
for to_identity_layer in to_identity_list: for to_identity_layer in to_identity_list:
if not self._set_identity(*to_identity_layer): if not set_identity(*to_identity_layer):
msg = "Failed to set the layers that after stop_layer_name to IdentityLayer." msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
logger.warning(msg) logger.warning(msg)
return False return False
@ -198,58 +169,135 @@ def unwrap_theseus(sub_layer):
return sub_layer return sub_layer
def slice_pattern(pattern, idx): def set_identity(parent_layer: nn.Layer,
layer_name: str,
layer_index: str=None) -> bool:
"""set the sub-layers that come after the layer specified by layer_name and layer_index to Identity.
Args:
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
layer_name (str): The name of target layer; the sub-layers after it in parent_layer are set to Identity.
layer_index (str, optional): The index of target layer in parent_layer; the sub-layers after it are set to Identity. Defaults to None.
Returns:
bool: True if the target layer was found, False otherwise.
"""
stop_after = False
for sub_layer_name in parent_layer._sub_layers:
if stop_after:
parent_layer._sub_layers[sub_layer_name] = Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
if layer_index and stop_after:
stop_after = False
for sub_layer_index in parent_layer._sub_layers[
layer_name]._sub_layers:
if stop_after:
parent_layer._sub_layers[layer_name][
sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List:
"""slice the string type "pattern" to list type by separator ".".
Args:
pattern (str): The pattern to describe layer name.
idx (Union[Tuple, int], optional): The index (int) or index range (Tuple of length 1 or 2) used to select part of the list split from "pattern". Defaults to None.
Returns:
List: The sub-list of the list obtained by splitting "pattern".
"""
pattern_list = pattern.split(".") pattern_list = pattern.split(".")
if idx: if idx:
if isinstance(idx, tuple): if isinstance(idx, Tuple):
if len(idx) == 1: if len(idx) == 1:
return pattern_list[idx[0]] return pattern_list[idx[0]]
elif len(idx) == 2: elif len(idx) == 2:
return pattern_list[idx[0]:idx[1]] return pattern_list[idx[0]:idx[1]]
else: else:
msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a tuple." msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple."
logger.warning(msg) logger.warning(msg)
return None return None
elif isinstance(idx, int): elif isinstance(idx, int):
return [pattern_list[idx]] return [pattern_list[idx]]
else: else:
msg = f"Only support type of 'idx' is int or tuple." msg = f"Only support type of 'idx' is int or Tuple."
logger.warning(msg) logger.warning(msg)
return None return None
return pattern_list return pattern_list
def parse_pattern_str(pattern, sub_layer_parent, idx=None): def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
idx=None) -> Dict[str, Union[nn.Layer, None, str]]:
"""parse the string type pattern.
Args:
pattern (str): The pattern to describe layer name.
parent_layer (nn.Layer): The parent layer of target layer(s) specified by "pattern".
idx (Union[Tuple, int], optional): The index(s) of sub-list of list sliced. Defaults to None.
Returns:
Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
Yields:
Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
"""
pattern_list = slice_pattern(pattern, idx) pattern_list = slice_pattern(pattern, idx)
if not pattern_list: if not pattern_list:
return None, None, None return None, None, None
while len(pattern_list) > 0: while len(pattern_list) > 0:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] target_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] target_layer_index = pattern_list[0].split('[')[1].split(']')[0]
else: else:
sub_layer_name = pattern_list[0] target_layer_name = pattern_list[0]
sub_layer_index = None target_layer_index = None
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None) target_layer = getattr(parent_layer, target_layer_name, None)
sub_layer_parent = unwrap_theseus(sub_layer_parent) target_layer = unwrap_theseus(target_layer)
if sub_layer_parent is None: if target_layer is None:
msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})." msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})."
logger.warning(msg) logger.warning(msg)
return None, sub_layer_name, sub_layer_index return {
"target_layer": None,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
if sub_layer_index and sub_layer_parent: if target_layer_index and target_layer:
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len( if int(target_layer_index) < 0 or int(target_layer_index) >= len(
sub_layer_parent): target_layer):
msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'." msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'."
logger.warning(msg) logger.warning(msg)
return None, sub_layer_name, sub_layer_index return {
sub_layer_parent = sub_layer_parent[sub_layer_index] "target_layer": None,
sub_layer_parent = unwrap_theseus(sub_layer_parent) "target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
target_layer = target_layer[target_layer_index]
target_layer = unwrap_theseus(target_layer)
yield {
"target_layer": target_layer,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
parent_layer = target_layer
return sub_layer_parent, sub_layer_name, sub_layer_index