fix: fix problems commented in reviewing
parent
41296972a4
commit
18dec0744a
|
@ -21,6 +21,7 @@ class TheseusLayer(nn.Layer):
|
|||
|
||||
def _return_dict_hook(self, layer, input, output):
|
||||
res_dict = {"output": output}
|
||||
# 'list' is needed to avoid error raised by popping self.res_dict
|
||||
for res_key in list(self.res_dict):
|
||||
res_dict[res_key] = self.res_dict.pop(res_key)
|
||||
return res_dict
|
||||
|
@ -28,12 +29,44 @@ class TheseusLayer(nn.Layer):
|
|||
def _save_sub_res_hook(self, layer, input, output):
|
||||
self.res_dict[self.res_name] = output
|
||||
|
||||
def _find_layers_handle(self, patterns, handle_func):
|
||||
def replace_sub(self,
|
||||
layer_name_pattern: Union[str, List[str]],
|
||||
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
|
||||
str, nn.Layer]:
|
||||
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
|
||||
|
||||
Args:
|
||||
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
|
||||
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.
|
||||
|
||||
Returns:
|
||||
Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'.
|
||||
|
||||
Examples:
|
||||
|
||||
from paddle import nn
|
||||
import paddleclas
|
||||
|
||||
def rep_func(sub_layer: nn.Layer, pattern: str):
|
||||
new_layer = nn.Conv2D(
|
||||
in_channels=sub_layer._in_channels,
|
||||
out_channels=sub_layer._out_channels,
|
||||
kernel_size=5,
|
||||
padding=2
|
||||
)
|
||||
return new_layer
|
||||
|
||||
net = paddleclas.MobileNetV1()
|
||||
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
||||
print(res)
|
||||
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
|
||||
"""
|
||||
if not isinstance(layer_name_pattern, list):
|
||||
layer_name_pattern = [layer_name_pattern]
|
||||
|
||||
handle_res_dict = {}
|
||||
for pattern in patterns:
|
||||
for pattern in layer_name_pattern:
|
||||
pattern_list = pattern.split(".")
|
||||
if not pattern_list:
|
||||
continue
|
||||
|
||||
# find parent layer of sub-layer specified by pattern
|
||||
sub_layer_parent = self
|
||||
|
@ -65,32 +98,30 @@ class TheseusLayer(nn.Layer):
|
|||
sub_layer_name = pattern_list[0]
|
||||
sub_layer_index = None
|
||||
|
||||
sub_layer = getattr(sub_layer_parent, sub_layer_name, False)
|
||||
sub_layer = getattr(sub_layer_parent, sub_layer_name, None)
|
||||
|
||||
if sub_layer is False:
|
||||
if not sub_layer:
|
||||
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
continue
|
||||
|
||||
try:
|
||||
sub_layer = sub_layer[
|
||||
sub_layer_index] if sub_layer_index is not None else sub_layer
|
||||
except KeyError as e:
|
||||
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
continue
|
||||
if sub_layer_index is not None:
|
||||
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
|
||||
sub_layer):
|
||||
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
continue
|
||||
sub_layer = sub_layer[sub_layer_index]
|
||||
|
||||
if not isinstance(sub_layer, TheseusLayer):
|
||||
sub_layer = wrap_theseus(sub_layer)
|
||||
new_sub_layer = handle_func(sub_layer, pattern)
|
||||
|
||||
if sub_layer_index:
|
||||
getattr(sub_layer_parent,
|
||||
sub_layer_name)[sub_layer_index] = sub_layer
|
||||
sub_layer_name)[sub_layer_index] = new_sub_layer
|
||||
else:
|
||||
setattr(sub_layer_parent, sub_layer_name, sub_layer)
|
||||
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
|
||||
|
||||
handle_res = handle_func(sub_layer, pattern)
|
||||
handle_res_dict[pattern] = handle_res
|
||||
handle_res_dict[pattern] = new_sub_layer
|
||||
return handle_res_dict
|
||||
|
||||
def _set_identity(self, layer, layer_name, layer_index=None):
|
||||
|
@ -113,45 +144,6 @@ class TheseusLayer(nn.Layer):
|
|||
|
||||
return stop_after
|
||||
|
||||
def replace_sub(self,
|
||||
layer_name_pattern: Union[str, List[str]],
|
||||
replace_function: Callable[[nn.Layer, str], Any]) -> Any:
|
||||
"""use 'replace_function' to modify the 'layer_name_pattern'.
|
||||
|
||||
Args:
|
||||
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'replace_function'.
|
||||
replace_function (FunctionType): The function to modify target layer specified by 'layer_name_pattern'.
|
||||
|
||||
Returns:
|
||||
bool: 'True' if successful, 'False' otherwise.
|
||||
|
||||
Examples:
|
||||
|
||||
from paddle import nn
|
||||
import paddleclas
|
||||
|
||||
def rep_func(warp_layer: nn.Layer, pattern: str):
|
||||
sub_layer = warp_layer.sub_layer
|
||||
new_layer = nn.Conv2D(
|
||||
in_channels=sub_layer._in_channels,
|
||||
out_channels=sub_layer._out_channels,
|
||||
kernel_size=5
|
||||
)
|
||||
warp_layer.sub_layer = new_layer
|
||||
return True
|
||||
|
||||
net = paddleclas.MobileNetV1()
|
||||
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func)
|
||||
print(res)
|
||||
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
|
||||
"""
|
||||
|
||||
if not isinstance(layer_name_pattern, list):
|
||||
layer_name_pattern = [layer_name_pattern]
|
||||
return self._find_layers_handle(
|
||||
layer_name_pattern, handle_func=replace_function)
|
||||
|
||||
# TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
|
||||
def stop_after(self, stop_layer_name: str) -> bool:
|
||||
"""stop forward and backward after 'stop_layer_name'.
|
||||
|
||||
|
@ -210,15 +202,11 @@ class TheseusLayer(nn.Layer):
|
|||
layer.res_dict = self.res_dict
|
||||
layer.res_name = pattern
|
||||
layer.register_forward_post_hook(layer._save_sub_res_hook)
|
||||
return True
|
||||
return layer
|
||||
|
||||
handle_func = Handler(self.res_dict)
|
||||
|
||||
if not isinstance(return_patterns, list):
|
||||
return_patterns = [return_patterns]
|
||||
|
||||
return self._find_layers_handle(
|
||||
return_patterns, handle_func=handle_func)
|
||||
return self.replace_sub(return_patterns, handle_func=handle_func)
|
||||
|
||||
|
||||
class WrapLayer(TheseusLayer):
|
||||
|
|
Loading…
Reference in New Issue