mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix: fix the error that containers nesting cannot be handled.
the error would be raised when when the pattern string represents nested, e.g., containing "[3][1]".
This commit is contained in:
parent
a1baf3f476
commit
a75dc8c993
@ -103,7 +103,7 @@ class TheseusLayer(nn.Layer):
|
||||
return new_layer
|
||||
|
||||
net = paddleclas.MobileNetV1()
|
||||
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
||||
res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
||||
print(res)
|
||||
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
|
||||
"""
|
||||
@ -117,18 +117,26 @@ class TheseusLayer(nn.Layer):
|
||||
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
|
||||
if not layer_list:
|
||||
continue
|
||||
|
||||
sub_layer_parent = layer_list[-2]["layer"] if len(
|
||||
layer_list) > 1 else self
|
||||
|
||||
sub_layer = layer_list[-1]["layer"]
|
||||
sub_layer_name = layer_list[-1]["name"]
|
||||
sub_layer_index = layer_list[-1]["index"]
|
||||
sub_layer_index_list = layer_list[-1]["index_list"]
|
||||
|
||||
new_sub_layer = handle_func(sub_layer, pattern)
|
||||
|
||||
if sub_layer_index:
|
||||
getattr(sub_layer_parent,
|
||||
sub_layer_name)[sub_layer_index] = new_sub_layer
|
||||
if sub_layer_index_list:
|
||||
if len(sub_layer_index_list) > 1:
|
||||
sub_layer_parent = getattr(
|
||||
sub_layer_parent,
|
||||
sub_layer_name)[sub_layer_index_list[0]]
|
||||
for sub_layer_index in sub_layer_index_list[1:-1]:
|
||||
sub_layer_parent = sub_layer_parent[sub_layer_index]
|
||||
sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
|
||||
else:
|
||||
getattr(sub_layer_parent, sub_layer_name)[
|
||||
sub_layer_index_list[0]] = new_sub_layer
|
||||
else:
|
||||
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
|
||||
|
||||
@ -151,8 +159,8 @@ class TheseusLayer(nn.Layer):
|
||||
|
||||
parent_layer = self
|
||||
for layer_dict in layer_list:
|
||||
name, index = layer_dict["name"], layer_dict["index"]
|
||||
if not set_identity(parent_layer, name, index):
|
||||
name, index_list = layer_dict["name"], layer_dict["index_list"]
|
||||
if not set_identity(parent_layer, name, index_list):
|
||||
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
|
||||
logger.warning(msg)
|
||||
return False
|
||||
@ -208,13 +216,13 @@ def save_sub_res_hook(layer, input, output):
|
||||
|
||||
def set_identity(parent_layer: nn.Layer,
|
||||
layer_name: str,
|
||||
layer_index: str=None) -> bool:
|
||||
"""set the layer specified by layer_name and layer_index to Indentity.
|
||||
layer_index_list: str=None) -> bool:
|
||||
"""set the layer specified by layer_name and layer_index_list to Indentity.
|
||||
|
||||
Args:
|
||||
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
|
||||
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index_list.
|
||||
layer_name (str): The name of target layer to be set to Indentity.
|
||||
layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
|
||||
layer_index_list (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
|
||||
|
||||
Returns:
|
||||
bool: True if successfully, False otherwise.
|
||||
@ -228,16 +236,19 @@ def set_identity(parent_layer: nn.Layer,
|
||||
if sub_layer_name == layer_name:
|
||||
stop_after = True
|
||||
|
||||
if layer_index and stop_after:
|
||||
stop_after = False
|
||||
for sub_layer_index in parent_layer._sub_layers[
|
||||
layer_name]._sub_layers:
|
||||
if stop_after:
|
||||
parent_layer._sub_layers[layer_name][
|
||||
sub_layer_index] = Identity()
|
||||
continue
|
||||
if layer_index == sub_layer_index:
|
||||
stop_after = True
|
||||
if layer_index_list and stop_after:
|
||||
layer_container = parent_layer._sub_layers[layer_name]
|
||||
for num, layer_index in enumerate(layer_index_list):
|
||||
stop_after = False
|
||||
for i in range(num):
|
||||
layer_container = layer_container[layer_index_list[i]]
|
||||
for sub_layer_index in layer_container._sub_layers:
|
||||
if stop_after:
|
||||
parent_layer._sub_layers[layer_name][
|
||||
sub_layer_index] = Identity()
|
||||
continue
|
||||
if layer_index == sub_layer_index:
|
||||
stop_after = True
|
||||
|
||||
return stop_after
|
||||
|
||||
@ -269,10 +280,12 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
|
||||
while len(pattern_list) > 0:
|
||||
if '[' in pattern_list[0]:
|
||||
target_layer_name = pattern_list[0].split('[')[0]
|
||||
target_layer_index = pattern_list[0].split('[')[1].split(']')[0]
|
||||
target_layer_index_list = list(
|
||||
index.split(']')[0]
|
||||
for index in pattern_list[0].split('[')[1:])
|
||||
else:
|
||||
target_layer_name = pattern_list[0]
|
||||
target_layer_index = None
|
||||
target_layer_index_list = None
|
||||
|
||||
target_layer = getattr(parent_layer, target_layer_name, None)
|
||||
|
||||
@ -281,21 +294,22 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
|
||||
logger.warning(msg)
|
||||
return None
|
||||
|
||||
if target_layer_index and target_layer:
|
||||
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
|
||||
target_layer):
|
||||
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
|
||||
target_layer = target_layer[target_layer_index]
|
||||
if target_layer_index_list:
|
||||
for target_layer_index in target_layer_index_list:
|
||||
if int(target_layer_index) < 0 or int(
|
||||
target_layer_index) >= len(target_layer):
|
||||
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
target_layer = target_layer[target_layer_index]
|
||||
|
||||
layer_list.append({
|
||||
"layer": target_layer,
|
||||
"name": target_layer_name,
|
||||
"index": target_layer_index
|
||||
"index_list": target_layer_index_list
|
||||
})
|
||||
|
||||
pattern_list = pattern_list[1:]
|
||||
parent_layer = target_layer
|
||||
|
||||
return layer_list
|
||||
|
Loading…
x
Reference in New Issue
Block a user