refactor: rename replace_sub() func to ?

pull/1602/head
gaotingquan 2021-12-21 14:12:43 +00:00 committed by Tingquan Gao
parent 0f126b75da
commit 56911b573b
1 changed files with 11 additions and 5 deletions
ppcls/arch/backbone/base

View File

@ -29,10 +29,16 @@ class TheseusLayer(nn.Layer):
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[self.res_name] = output
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
str, nn.Layer]:
def replace_sub(self, *args, **kwargs) -> None:
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)
# TODO(gaotingquan): what is a good name?
def layer_wrench(self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
str, nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
@ -59,7 +65,7 @@ class TheseusLayer(nn.Layer):
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
"""
if not isinstance(layer_name_pattern, list):