refactor: rename replace_sub() func to ?
parent
0f126b75da
commit
56911b573b
ppcls/arch/backbone/base
|
@ -29,10 +29,16 @@ class TheseusLayer(nn.Layer):
|
|||
def _save_sub_res_hook(self, layer, input, output):
|
||||
self.res_dict[self.res_name] = output
|
||||
|
||||
def replace_sub(self,
|
||||
layer_name_pattern: Union[str, List[str]],
|
||||
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
|
||||
str, nn.Layer]:
|
||||
def replace_sub(self, *args, **kwargs) -> None:
|
||||
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead."
|
||||
logger.error(DeprecationWarning(msg))
|
||||
raise DeprecationWarning(msg)
|
||||
|
||||
# TODO(gaotingquan): what is a good name?
|
||||
def layer_wrench(self,
|
||||
layer_name_pattern: Union[str, List[str]],
|
||||
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
|
||||
str, nn.Layer]:
|
||||
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
|
||||
|
||||
Args:
|
||||
|
@ -59,7 +65,7 @@ class TheseusLayer(nn.Layer):
|
|||
net = paddleclas.MobileNetV1()
|
||||
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
|
||||
print(res)
|
||||
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
|
||||
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
|
||||
"""
|
||||
|
||||
if not isinstance(layer_name_pattern, list):
|
||||
|
|
Loading…
Reference in New Issue