2021-12-20 14:34:11 +08:00
from typing import List , Dict , Union , Callable , Any
2021-05-24 11:42:24 +08:00
from paddle import nn
2021-12-17 22:37:18 +08:00
from ppcls . utils import logger
2021-05-24 11:42:24 +08:00
class Identity(nn.Layer):
    """Pass-through layer: ``forward`` hands back exactly what it receives."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` unchanged."""
        return inputs
2021-05-25 16:47:33 +08:00
class TheseusLayer(nn.Layer):
    """Layer base class whose sub-layers can be replaced, truncated or tapped by pattern.

    A pattern is a dot-separated attribute path from this layer down to a
    sub-layer; a segment may carry an index in square brackets, e.g.
    ``"blocks[11].depthwise_conv.conv"``.
    """

    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        # Outputs stored by `_save_sub_res_hook`, keyed by `res_name`.
        self.res_dict = {}
        # Key under which this layer's own output is stored in `res_dict`.
        self.res_name = self.full_name()
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Bundle ``output`` with every result currently held in ``self.res_dict``.

        Has the ``(layer, input, output)`` signature of a forward post-hook;
        where it gets registered is not shown in this file.

        Returns:
            dict: ``{"output": output}`` plus all entries drained (popped)
                from ``self.res_dict``.
        """
        res_dict = {"output": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Forward post-hook that stores ``output`` in ``self.res_dict`` under ``self.res_name``."""
        self.res_dict[self.res_name] = output

    def replace_sub(self,
                    layer_name_pattern: Union[str, List[str]],
                    handle_func: Callable[[nn.Layer, str], nn.Layer]
                    ) -> Dict[str, nn.Layer]:
        """Use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer
                specified by 'layer_name_pattern'. It receives the matched sub-layer and its
                pattern, and returns the layer to install in its place.

        Returns:
            Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result
                returned by 'handle_func'. Patterns that could not be resolved are omitted.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(sub_layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=sub_layer._in_channels,
                    out_channels=sub_layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # {'blocks[11].depthwise_conv.conv': <the new Conv2D>, 'blocks[12].depthwise_conv.conv': <the new Conv2D>}
        """
        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        handle_res_dict = {}
        for pattern in layer_name_pattern:
            if "." in pattern:
                # find parent layer of sub-layer specified by pattern
                sub_layer_parent, _, _ = parse_pattern_str(
                    pattern=pattern, idx=(0, -1), sub_layer_parent=self)
                # Compare with None explicitly: container layers define
                # __len__, so an empty one would be falsy yet still valid.
                if sub_layer_parent is None:
                    continue
            else:
                # A single-segment pattern names a direct child of this
                # layer; there are no leading segments to resolve.
                sub_layer_parent = self

            # find sub-layer specified by pattern
            sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str(
                pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent)
            if sub_layer is None:
                continue

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index:
                # Indexed segment: assign into the container held by the parent.
                container = getattr(sub_layer_parent, sub_layer_name)
                container[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            handle_res_dict[pattern] = new_sub_layer
        return handle_res_dict

    def _set_identity(self, layer, layer_name, layer_index=None):
        """Replace every sub-layer registered after the target with ``Identity``.

        Args:
            layer: The parent layer whose ``_sub_layers`` are scanned in order.
            layer_name: Name of the child after which siblings are replaced.
            layer_index: Optional key inside the child named ``layer_name``;
                when given, entries of that child after this key are replaced too.

        Returns:
            bool: 'True' if the target (name, and index when given) was found,
                'False' otherwise.
        """
        found = False
        for sub_layer_name in layer._sub_layers:
            if found:
                layer._sub_layers[sub_layer_name] = Identity()
                continue
            if sub_layer_name == layer_name:
                found = True

        if layer_index and found:
            # Repeat the scan one level down, inside the indexed container.
            found = False
            for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
                if found:
                    layer._sub_layers[layer_name][sub_layer_index] = Identity()
                    continue
                if layer_index == sub_layer_index:
                    found = True

        return found

    def stop_after(self, stop_layer_name: str) -> bool:
        """Stop forward and backward after 'stop_layer_name'.

        Every layer that follows the target, at each level of the path, is
        replaced by ``Identity``.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        to_identity_list = []

        # TODO(gaotingquan): replace code by self._parse_pattern_str()
        layer = self
        for pattern_item in stop_layer_name.split("."):
            layer_parent = layer
            if "[" in pattern_item:
                sub_layer_name = pattern_item.split("[")[0]
                sub_layer_index = pattern_item.split("[")[1].split("]")[0]
                # Look the container up with a default so that a missing
                # attribute is reported below instead of raising.
                container = getattr(layer, sub_layer_name, None)
                try:
                    layer = None if container is None else container[
                        sub_layer_index]
                except (KeyError, IndexError):
                    layer = None
            else:
                sub_layer_name = pattern_item
                sub_layer_index = None
                layer = getattr(layer, sub_layer_name, None)

            if layer is None:
                msg = f"Not found layer by name({pattern_item}) specifed in stop_layer_name({stop_layer_name})."
                logger.warning(msg)
                return False

            to_identity_list.append(
                (layer_parent, sub_layer_name, sub_layer_index))

        for to_identity_layer in to_identity_list:
            if not self._set_identity(*to_identity_layer):
                msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
                logger.warning(msg)
                return False

        return True

    def update_res(self,
                   return_patterns: Union[str, List[str]]
                   ) -> Dict[str, nn.Layer]:
        """Update the results to be returned.

        Each matched sub-layer is pointed at this layer's ``res_dict`` and gets
        its ``_save_sub_res_hook`` registered as a forward post-hook, so the
        matched layer must provide that method.

        Args:
            return_patterns (Union[str, List[str]]): The name of layer to return output.

        Returns:
            Dict[str, nn.Layer]: Maps each pattern that was set successfully to
                its layer; patterns that failed are omitted.
        """

        class Handler(object):
            def __init__(self, res_dict):
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                # Share the owner's result dict and save under the pattern.
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                layer.register_forward_post_hook(layer._save_sub_res_hook)
                return layer

        handle_func = Handler(self.res_dict)

        return self.replace_sub(return_patterns, handle_func=handle_func)
2021-08-08 14:57:29 +08:00
class WrapLayer(TheseusLayer):
    """TheseusLayer that holds one sub-layer and delegates ``forward`` to it."""

    def __init__(self, sub_layer):
        super().__init__()
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        """Call the wrapped layer with the given arguments and return its result."""
        wrapped = self.sub_layer
        return wrapped(*inputs, **kwargs)
2021-08-08 14:57:29 +08:00
2021-10-15 18:25:50 +08:00
def wrap_theseus(sub_layer):
    """Return a new ``WrapLayer`` that wraps ``sub_layer``."""
    wrapped_layer = WrapLayer(sub_layer)
    return wrapped_layer
def unwrap_theseus(sub_layer):
    """Return the layer held by a ``WrapLayer``; any other value is returned as is."""
    return sub_layer.sub_layer if isinstance(sub_layer, WrapLayer) else sub_layer
def slice_pattern(pattern, idx=None):
    """Split a dotted layer pattern into segments and select some of them.

    Args:
        pattern (str): Dot-separated pattern, e.g. ``"blocks[11].conv"``.
        idx (Union[int, tuple, None]): Which segments to keep.
            ``None`` (or an empty tuple) keeps all of them; an ``int`` or a
            1-tuple keeps the single segment at that position; a 2-tuple
            ``(start, stop)`` keeps ``segments[start:stop]``.

    Returns:
        Optional[List[str]]: The selected segments, always as a list (possibly
            empty), or ``None`` after logging a warning when ``idx`` has an
            unsupported type or tuple length.
    """
    pattern_list = pattern.split(".")

    # Test for "no slicing requested" explicitly: a plain truthiness check
    # would also swallow idx == 0, which is a valid position.
    if idx is None or (isinstance(idx, tuple) and not idx):
        return pattern_list

    if isinstance(idx, tuple):
        if len(idx) == 1:
            # Wrap in a list so every successful result has the same shape.
            return [pattern_list[idx[0]]]
        if len(idx) == 2:
            return pattern_list[idx[0]:idx[1]]
        msg = "Only support length of 'idx' is 1 or 2 when 'idx' is a tuple."
        logger.warning(msg)
        return None

    if isinstance(idx, int):
        return [pattern_list[idx]]

    msg = "Only support type of 'idx' is int or tuple."
    logger.warning(msg)
    return None
def parse_pattern_str(pattern, sub_layer_parent, idx=None):
    """Resolve (a slice of) a dotted pattern to a layer, starting at ``sub_layer_parent``.

    Each selected segment is looked up as an attribute of the current layer;
    a segment of the form ``name[i]`` additionally indexes into that
    attribute. Every layer found along the way is passed through
    ``unwrap_theseus``.

    Args:
        pattern (str): Dot-separated pattern, e.g. ``"blocks[11].conv"``.
        sub_layer_parent: The layer at which resolution starts.
        idx: Segment selector forwarded to ``slice_pattern``.

    Returns:
        tuple: ``(layer, name, index)`` where ``layer`` is the resolved layer
            and ``name`` / ``index`` are the attribute name and index string
            (or ``None``) of the last segment processed. ``layer`` is ``None``
            when resolution failed. When the selection is empty there is
            nothing to walk, so ``(sub_layer_parent, None, None)`` is returned.
    """
    pattern_list = slice_pattern(pattern, idx)
    if pattern_list is None:
        # 'idx' was rejected by slice_pattern, which has logged the reason.
        return None, None, None
    if isinstance(pattern_list, str):
        # Tolerate a bare segment in place of a one-element list.
        pattern_list = [pattern_list]
    if not pattern_list:
        # An empty selection (e.g. the parent path of a single-segment
        # pattern) resolves to the starting layer itself.
        return sub_layer_parent, None, None

    sub_layer_name = None
    sub_layer_index = None
    for pattern_item in pattern_list:
        if "[" in pattern_item:
            sub_layer_name = pattern_item.split("[")[0]
            sub_layer_index = pattern_item.split("[")[1].split("]")[0]
        else:
            sub_layer_name = pattern_item
            sub_layer_index = None

        sub_layer_parent = unwrap_theseus(
            getattr(sub_layer_parent, sub_layer_name, None))
        if sub_layer_parent is None:
            msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})."
            logger.warning(msg)
            return None, sub_layer_name, sub_layer_index

        if sub_layer_index:
            try:
                position = int(sub_layer_index)
            except ValueError:
                msg = f"Not found layer by index({sub_layer_index}) specified in pattern({pattern}). The index should be an integer."
                logger.warning(msg)
                return None, sub_layer_name, sub_layer_index

            if position < 0 or position >= len(sub_layer_parent):
                msg = f"Not found layer by index({sub_layer_index}) specified in pattern({pattern}). The index should be < {len(sub_layer_parent)} and >= 0."
                logger.warning(msg)
                return None, sub_layer_name, sub_layer_index

            # The index is kept as the string taken from the pattern, both
            # for the lookup and in the returned tuple.
            sub_layer_parent = unwrap_theseus(
                sub_layer_parent[sub_layer_index])

    return sub_layer_parent, sub_layer_name, sub_layer_index