fix: fix result returned by stop_after

pull/1602/head
gaotingquan 2021-12-20 06:34:11 +00:00 committed by Tingquan Gao
parent cf205e1379
commit f8ee6c0f86
1 changed files with 32 additions and 15 deletions

View File

@ -1,4 +1,4 @@
from typing import List, Union, Callable, Any from typing import List, Dict, Union, Callable, Any
from paddle import nn from paddle import nn
from ppcls.utils import logger from ppcls.utils import logger
@ -19,7 +19,6 @@ class TheseusLayer(nn.Layer):
self.pruner = None self.pruner = None
self.quanter = None self.quanter = None
# TODO(gaotingquan): weishengyu
def _return_dict_hook(self, layer, input, output): def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output} res_dict = {"output": output}
for res_key in list(self.res_dict): for res_key in list(self.res_dict):
@ -54,7 +53,7 @@ class TheseusLayer(nn.Layer):
sub_layer_parent = sub_layer_parent.sub_layer sub_layer_parent = sub_layer_parent.sub_layer
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
if sub_layer_parent is None: if sub_layer_parent is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in pattern({pattern})." msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg) logger.warning(msg)
continue continue
@ -62,17 +61,33 @@ class TheseusLayer(nn.Layer):
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
sub_layer = getattr(sub_layer_parent, else:
sub_layer_name)[sub_layer_index] sub_layer_name = pattern_list[0]
if not isinstance(sub_layer, TheseusLayer): sub_layer_index = None
sub_layer = wrap_theseus(sub_layer)
sub_layer = getattr(sub_layer_parent, sub_layer_name, False)
if sub_layer is False:
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
try:
sub_layer = sub_layer[
sub_layer_index] if sub_layer_index is not None else sub_layer
except KeyError as e:
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
logger.warning(msg)
continue
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
if sub_layer_index:
getattr(sub_layer_parent, getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer sub_layer_name)[sub_layer_index] = sub_layer
else: else:
sub_layer = getattr(sub_layer_parent, pattern_list[0]) setattr(sub_layer_parent, sub_layer_name, sub_layer)
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer)
handle_res = handle_func(sub_layer, pattern) handle_res = handle_func(sub_layer, pattern)
handle_res_dict[pattern] = handle_res handle_res_dict[pattern] = handle_res
@ -136,7 +151,7 @@ class TheseusLayer(nn.Layer):
return self._find_layers_handle( return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function) layer_name_pattern, handle_func=replace_function)
# stop doesn't work when stop layer has a parallel branch. # TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
@ -176,14 +191,15 @@ class TheseusLayer(nn.Layer):
return False return False
return True return True
def update_res(self, return_patterns: Union[str, List[str]]) -> bool: def update_res(self,
return_patterns: Union[str, List[str]]) -> Dict[str, bool]:
"""update the results needed returned. """update the results needed returned.
Args: Args:
return_patterns (Union[str, List[str]]): The layer(s)' name to be retruened. return_patterns (Union[str, List[str]]): [description]
Returns: Returns:
bool: 'True' if successful, 'False' otherwise. Dict[str, bool]: The pattern(str) is be set successfully if True(bool), failed otherwise.
""" """
class Handler(object): class Handler(object):
@ -194,6 +210,7 @@ class TheseusLayer(nn.Layer):
layer.res_dict = self.res_dict layer.res_dict = self.res_dict
layer.res_name = pattern layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook) layer.register_forward_post_hook(layer._save_sub_res_hook)
return True
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)