fix: fix the error that containers nesting cannot be handled.

the error would be raised when when the pattern string represents nested, e.g., containing "[3][1]".
This commit is contained in:
gaotingquan 2022-08-24 03:24:33 +00:00 committed by Tingquan Gao
parent a1baf3f476
commit a75dc8c993

View File

@ -103,7 +103,7 @@ class TheseusLayer(nn.Layer):
return new_layer return new_layer
net = paddleclas.MobileNetV1() net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func) res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res) print(res)
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer} # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
""" """
@ -117,18 +117,26 @@ class TheseusLayer(nn.Layer):
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self) layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
if not layer_list: if not layer_list:
continue continue
sub_layer_parent = layer_list[-2]["layer"] if len( sub_layer_parent = layer_list[-2]["layer"] if len(
layer_list) > 1 else self layer_list) > 1 else self
sub_layer = layer_list[-1]["layer"] sub_layer = layer_list[-1]["layer"]
sub_layer_name = layer_list[-1]["name"] sub_layer_name = layer_list[-1]["name"]
sub_layer_index = layer_list[-1]["index"] sub_layer_index_list = layer_list[-1]["index_list"]
new_sub_layer = handle_func(sub_layer, pattern) new_sub_layer = handle_func(sub_layer, pattern)
if sub_layer_index: if sub_layer_index_list:
getattr(sub_layer_parent, if len(sub_layer_index_list) > 1:
sub_layer_name)[sub_layer_index] = new_sub_layer sub_layer_parent = getattr(
sub_layer_parent,
sub_layer_name)[sub_layer_index_list[0]]
for sub_layer_index in sub_layer_index_list[1:-1]:
sub_layer_parent = sub_layer_parent[sub_layer_index]
sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
else:
getattr(sub_layer_parent, sub_layer_name)[
sub_layer_index_list[0]] = new_sub_layer
else: else:
setattr(sub_layer_parent, sub_layer_name, new_sub_layer) setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
@ -151,8 +159,8 @@ class TheseusLayer(nn.Layer):
parent_layer = self parent_layer = self
for layer_dict in layer_list: for layer_dict in layer_list:
name, index = layer_dict["name"], layer_dict["index"] name, index_list = layer_dict["name"], layer_dict["index_list"]
if not set_identity(parent_layer, name, index): if not set_identity(parent_layer, name, index_list):
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'." msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
logger.warning(msg) logger.warning(msg)
return False return False
@ -208,13 +216,13 @@ def save_sub_res_hook(layer, input, output):
def set_identity(parent_layer: nn.Layer, def set_identity(parent_layer: nn.Layer,
layer_name: str, layer_name: str,
layer_index: str=None) -> bool: layer_index_list: str=None) -> bool:
"""set the layer specified by layer_name and layer_index to Indentity. """set the layer specified by layer_name and layer_index_list to Indentity.
Args: Args:
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index. parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index_list.
layer_name (str): The name of target layer to be set to Indentity. layer_name (str): The name of target layer to be set to Indentity.
layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None. layer_index_list (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
Returns: Returns:
bool: True if successfully, False otherwise. bool: True if successfully, False otherwise.
@ -228,10 +236,13 @@ def set_identity(parent_layer: nn.Layer,
if sub_layer_name == layer_name: if sub_layer_name == layer_name:
stop_after = True stop_after = True
if layer_index and stop_after: if layer_index_list and stop_after:
layer_container = parent_layer._sub_layers[layer_name]
for num, layer_index in enumerate(layer_index_list):
stop_after = False stop_after = False
for sub_layer_index in parent_layer._sub_layers[ for i in range(num):
layer_name]._sub_layers: layer_container = layer_container[layer_index_list[i]]
for sub_layer_index in layer_container._sub_layers:
if stop_after: if stop_after:
parent_layer._sub_layers[layer_name][ parent_layer._sub_layers[layer_name][
sub_layer_index] = Identity() sub_layer_index] = Identity()
@ -269,10 +280,12 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
while len(pattern_list) > 0: while len(pattern_list) > 0:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
target_layer_name = pattern_list[0].split('[')[0] target_layer_name = pattern_list[0].split('[')[0]
target_layer_index = pattern_list[0].split('[')[1].split(']')[0] target_layer_index_list = list(
index.split(']')[0]
for index in pattern_list[0].split('[')[1:])
else: else:
target_layer_name = pattern_list[0] target_layer_name = pattern_list[0]
target_layer_index = None target_layer_index_list = None
target_layer = getattr(parent_layer, target_layer_name, None) target_layer = getattr(parent_layer, target_layer_name, None)
@ -281,21 +294,22 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
logger.warning(msg) logger.warning(msg)
return None return None
if target_layer_index and target_layer: if target_layer_index_list:
if int(target_layer_index) < 0 or int(target_layer_index) >= len( for target_layer_index in target_layer_index_list:
target_layer): if int(target_layer_index) < 0 or int(
target_layer_index) >= len(target_layer):
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0." msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
logger.warning(msg) logger.warning(msg)
return None return None
target_layer = target_layer[target_layer_index] target_layer = target_layer[target_layer_index]
layer_list.append({ layer_list.append({
"layer": target_layer, "layer": target_layer,
"name": target_layer_name, "name": target_layer_name,
"index": target_layer_index "index_list": target_layer_index_list
}) })
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
parent_layer = target_layer parent_layer = target_layer
return layer_list return layer_list