mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix: fix the error that containers nesting cannot be handled.
the error would be raised when when the pattern string represents nested, e.g., containing "[3][1]".
This commit is contained in:
parent
a1baf3f476
commit
a75dc8c993
@ -103,7 +103,7 @@ class TheseusLayer(nn.Layer):
|
|||||||
return new_layer
|
return new_layer
|
||||||
|
|
||||||
net = paddleclas.MobileNetV1()
|
net = paddleclas.MobileNetV1()
|
||||||
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
||||||
print(res)
|
print(res)
|
||||||
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
|
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
|
||||||
"""
|
"""
|
||||||
@ -117,18 +117,26 @@ class TheseusLayer(nn.Layer):
|
|||||||
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
|
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
|
||||||
if not layer_list:
|
if not layer_list:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sub_layer_parent = layer_list[-2]["layer"] if len(
|
sub_layer_parent = layer_list[-2]["layer"] if len(
|
||||||
layer_list) > 1 else self
|
layer_list) > 1 else self
|
||||||
|
|
||||||
sub_layer = layer_list[-1]["layer"]
|
sub_layer = layer_list[-1]["layer"]
|
||||||
sub_layer_name = layer_list[-1]["name"]
|
sub_layer_name = layer_list[-1]["name"]
|
||||||
sub_layer_index = layer_list[-1]["index"]
|
sub_layer_index_list = layer_list[-1]["index_list"]
|
||||||
|
|
||||||
new_sub_layer = handle_func(sub_layer, pattern)
|
new_sub_layer = handle_func(sub_layer, pattern)
|
||||||
|
|
||||||
if sub_layer_index:
|
if sub_layer_index_list:
|
||||||
getattr(sub_layer_parent,
|
if len(sub_layer_index_list) > 1:
|
||||||
sub_layer_name)[sub_layer_index] = new_sub_layer
|
sub_layer_parent = getattr(
|
||||||
|
sub_layer_parent,
|
||||||
|
sub_layer_name)[sub_layer_index_list[0]]
|
||||||
|
for sub_layer_index in sub_layer_index_list[1:-1]:
|
||||||
|
sub_layer_parent = sub_layer_parent[sub_layer_index]
|
||||||
|
sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
|
||||||
|
else:
|
||||||
|
getattr(sub_layer_parent, sub_layer_name)[
|
||||||
|
sub_layer_index_list[0]] = new_sub_layer
|
||||||
else:
|
else:
|
||||||
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
|
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
|
||||||
|
|
||||||
@ -151,8 +159,8 @@ class TheseusLayer(nn.Layer):
|
|||||||
|
|
||||||
parent_layer = self
|
parent_layer = self
|
||||||
for layer_dict in layer_list:
|
for layer_dict in layer_list:
|
||||||
name, index = layer_dict["name"], layer_dict["index"]
|
name, index_list = layer_dict["name"], layer_dict["index_list"]
|
||||||
if not set_identity(parent_layer, name, index):
|
if not set_identity(parent_layer, name, index_list):
|
||||||
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
|
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
return False
|
return False
|
||||||
@ -208,13 +216,13 @@ def save_sub_res_hook(layer, input, output):
|
|||||||
|
|
||||||
def set_identity(parent_layer: nn.Layer,
|
def set_identity(parent_layer: nn.Layer,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
layer_index: str=None) -> bool:
|
layer_index_list: str=None) -> bool:
|
||||||
"""set the layer specified by layer_name and layer_index to Indentity.
|
"""set the layer specified by layer_name and layer_index_list to Indentity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
|
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index_list.
|
||||||
layer_name (str): The name of target layer to be set to Indentity.
|
layer_name (str): The name of target layer to be set to Indentity.
|
||||||
layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
|
layer_index_list (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if successfully, False otherwise.
|
bool: True if successfully, False otherwise.
|
||||||
@ -228,10 +236,13 @@ def set_identity(parent_layer: nn.Layer,
|
|||||||
if sub_layer_name == layer_name:
|
if sub_layer_name == layer_name:
|
||||||
stop_after = True
|
stop_after = True
|
||||||
|
|
||||||
if layer_index and stop_after:
|
if layer_index_list and stop_after:
|
||||||
|
layer_container = parent_layer._sub_layers[layer_name]
|
||||||
|
for num, layer_index in enumerate(layer_index_list):
|
||||||
stop_after = False
|
stop_after = False
|
||||||
for sub_layer_index in parent_layer._sub_layers[
|
for i in range(num):
|
||||||
layer_name]._sub_layers:
|
layer_container = layer_container[layer_index_list[i]]
|
||||||
|
for sub_layer_index in layer_container._sub_layers:
|
||||||
if stop_after:
|
if stop_after:
|
||||||
parent_layer._sub_layers[layer_name][
|
parent_layer._sub_layers[layer_name][
|
||||||
sub_layer_index] = Identity()
|
sub_layer_index] = Identity()
|
||||||
@ -269,10 +280,12 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
|
|||||||
while len(pattern_list) > 0:
|
while len(pattern_list) > 0:
|
||||||
if '[' in pattern_list[0]:
|
if '[' in pattern_list[0]:
|
||||||
target_layer_name = pattern_list[0].split('[')[0]
|
target_layer_name = pattern_list[0].split('[')[0]
|
||||||
target_layer_index = pattern_list[0].split('[')[1].split(']')[0]
|
target_layer_index_list = list(
|
||||||
|
index.split(']')[0]
|
||||||
|
for index in pattern_list[0].split('[')[1:])
|
||||||
else:
|
else:
|
||||||
target_layer_name = pattern_list[0]
|
target_layer_name = pattern_list[0]
|
||||||
target_layer_index = None
|
target_layer_index_list = None
|
||||||
|
|
||||||
target_layer = getattr(parent_layer, target_layer_name, None)
|
target_layer = getattr(parent_layer, target_layer_name, None)
|
||||||
|
|
||||||
@ -281,21 +294,22 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
|
|||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if target_layer_index and target_layer:
|
if target_layer_index_list:
|
||||||
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
|
for target_layer_index in target_layer_index_list:
|
||||||
target_layer):
|
if int(target_layer_index) < 0 or int(
|
||||||
|
target_layer_index) >= len(target_layer):
|
||||||
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
|
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
target_layer = target_layer[target_layer_index]
|
target_layer = target_layer[target_layer_index]
|
||||||
|
|
||||||
layer_list.append({
|
layer_list.append({
|
||||||
"layer": target_layer,
|
"layer": target_layer,
|
||||||
"name": target_layer_name,
|
"name": target_layer_name,
|
||||||
"index": target_layer_index
|
"index_list": target_layer_index_list
|
||||||
})
|
})
|
||||||
|
|
||||||
pattern_list = pattern_list[1:]
|
pattern_list = pattern_list[1:]
|
||||||
parent_layer = target_layer
|
parent_layer = target_layer
|
||||||
|
|
||||||
return layer_list
|
return layer_list
|
||||||
|
Loading…
x
Reference in New Issue
Block a user