refactor: rename replace_sub() func to ?
parent
0f126b75da
commit
56911b573b
|
@ -29,10 +29,16 @@ class TheseusLayer(nn.Layer):
|
||||||
def _save_sub_res_hook(self, layer, input, output):
|
def _save_sub_res_hook(self, layer, input, output):
|
||||||
self.res_dict[self.res_name] = output
|
self.res_dict[self.res_name] = output
|
||||||
|
|
||||||
def replace_sub(self,
|
def replace_sub(self, *args, **kwargs) -> None:
|
||||||
layer_name_pattern: Union[str, List[str]],
|
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead."
|
||||||
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
|
logger.error(DeprecationWarning(msg))
|
||||||
str, nn.Layer]:
|
raise DeprecationWarning(msg)
|
||||||
|
|
||||||
|
# TODO(gaotingquan): what is a good name?
|
||||||
|
def layer_wrench(self,
|
||||||
|
layer_name_pattern: Union[str, List[str]],
|
||||||
|
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
|
||||||
|
str, nn.Layer]:
|
||||||
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
|
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -59,7 +65,7 @@ class TheseusLayer(nn.Layer):
|
||||||
net = paddleclas.MobileNetV1()
|
net = paddleclas.MobileNetV1()
|
||||||
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
||||||
print(res)
|
print(res)
|
||||||
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
|
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(layer_name_pattern, list):
|
if not isinstance(layer_name_pattern, list):
|
||||||
|
|
Loading…
Reference in New Issue