2021-12-21 21:51:45 +08:00
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
2021-05-24 11:42:24 +08:00
from paddle import nn
2021-12-17 22:37:18 +08:00
from ppcls . utils import logger
2021-05-24 11:42:24 +08:00
class Identity(nn.Layer):
    """Pass-through layer: 'forward' hands back exactly what it receives."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # No computation is performed; the very same object is returned.
        return inputs
2021-05-25 16:47:33 +08:00
class TheseusLayer(nn.Layer):
    """Layer whose sub-layers can be addressed by a dotted name pattern.

    A pattern such as "blocks[11].depthwise_conv.conv" is resolved segment by
    segment with 'parse_pattern_str'. The addressed sub-layer can then be
    modified/replaced ('layer_wrench'), used as a truncation point
    ('stop_after') or have its output recorded ('update_res').
    """

    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        # Outputs recorded by '_save_sub_res_hook', keyed by 'res_name'.
        self.res_dict = {}
        self.res_name = self.full_name()
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Merge 'output' with every recorded intermediate result into one dict.

        Has the (layer, input, output) signature of a forward post-hook; it is
        not registered anywhere in this class (presumably done by subclasses).
        The recorded results are moved out of 'self.res_dict', leaving it empty.
        """
        res_dict = {" output ": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Forward post-hook that records 'output' under 'self.res_name'."""
        self.res_dict[self.res_name] = output

    def replace_sub(self, *args, **kwargs) -> None:
        """Deprecated entry point; always raises.

        Raises:
            DeprecationWarning: Always. Use 'layer_wrench' instead.
        """
        msg = " \" replace_sub \" is deprecated, please use \" layer_wrench \" instead. "
        logger.error(DeprecationWarning(msg))
        raise DeprecationWarning(msg)

    def _resolve_pattern(self, pattern: str) -> List[Dict[str, Any]]:
        """Resolve every segment of 'pattern', starting from this layer.

        Args:
            pattern (str): The dotted pattern describing a sub-layer.

        Returns:
            List[Dict[str, Any]]: One dict per segment, as yielded by
                'parse_pattern_str'; an empty list if any segment could not
                be resolved.
        """
        resolved = list(parse_pattern_str(pattern=pattern, parent_layer=self))
        # 'parse_pattern_str' stops yielding at the first segment it cannot
        # find, so a complete match has exactly one entry per segment.
        if len(resolved) != len(slice_pattern(pattern)):
            return []
        return resolved

    # TODO(gaotingquan): what is a good name?
    def layer_wrench(self,
                     layer_name_pattern: Union[str, List[str]],
                     handle_func: Callable[[nn.Layer, str], nn.Layer]
                     ) -> Dict[str, nn.Layer]:
        """Use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Patterns that cannot be fully resolved are skipped (a warning is
        logged by 'parse_pattern_str') and do not appear in the result.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name pattern(s) of the layer(s) to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): Called as 'handle_func(sub_layer, pattern)'; the layer it returns is put in place of 'sub_layer'.

        Returns:
            Dict[str, nn.Layer]: The key is the pattern and the corresponding value is the result returned by 'handle_func'.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(sub_layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=sub_layer._in_channels,
                    out_channels=sub_layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.layer_wrench(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
        """
        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        handle_res_dict = {}
        for pattern in layer_name_pattern:
            resolved = self._resolve_pattern(pattern)
            if not resolved:
                continue

            # The last entry describes the addressed sub-layer and its direct
            # parent; for a single-segment pattern the parent is 'self'.
            target = resolved[-1]
            sub_layer = target[" target_layer "]
            sub_layer_name = target[" target_layer_name "]
            sub_layer_index = target[" target_layer_index "]
            sub_layer_parent = target[" parent_layer "]
            if sub_layer is None:
                continue

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index:
                # The attribute is a container; replace only the indexed item.
                # Unwrap first, mirroring what 'parse_pattern_str' does when
                # it reads through the same attribute.
                container = unwrap_theseus(
                    getattr(sub_layer_parent, sub_layer_name))
                container[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            handle_res_dict[pattern] = new_sub_layer
        return handle_res_dict

    def stop_after(self, stop_layer_name: str) -> bool:
        """Stop forward and backward after 'stop_layer_name'.

        For every level of the pattern, 'set_identity' is applied to the
        parent resolved at that level. Nothing is modified when the pattern
        cannot be fully resolved.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        # Resolve the whole pattern before mutating anything: 'set_identity'
        # rewrites sub-layers that the resolution walks through.
        resolved = self._resolve_pattern(stop_layer_name)
        if not resolved:
            return False

        for target in resolved:
            if not set_identity(target[" parent_layer "],
                                target[" target_layer_name "],
                                target[" target_layer_index "]):
                msg = " Failed to set the layers that after stop_layer_name to IdentityLayer. "
                logger.warning(msg)
                return False
        return True

    def update_res(self,
                   return_patterns: Union[str, List[str]]
                   ) -> Dict[str, nn.Layer]:
        """Update the results to be returned.

        Registers '_save_sub_res_hook' as a forward post-hook on each layer
        matched by 'return_patterns', so that its output is recorded in this
        layer's 'res_dict' under the pattern string.

        Args:
            return_patterns (Union[str, List[str]]): The name pattern(s) of the layer(s) whose output should be recorded.

        Returns:
            Dict[str, nn.Layer]: The result of 'layer_wrench': each successfully handled pattern mapped to its layer. Patterns that could not be resolved are absent.
        """

        class Handler(object):
            def __init__(self, res_dict):
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                # Share the outer 'res_dict' so results end up in one place.
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                # NOTE(review): requires 'layer' to provide
                # '_save_sub_res_hook' (i.e. to be a TheseusLayer).
                layer.register_forward_post_hook(layer._save_sub_res_hook)
                return layer

        handle_func = Handler(self.res_dict)

        # 'replace_sub' is deprecated and always raises; use 'layer_wrench'.
        return self.layer_wrench(return_patterns, handle_func=handle_func)
2021-08-08 14:57:29 +08:00
class WrapLayer(TheseusLayer):
    """TheseusLayer that holds another layer in 'sub_layer' and delegates 'forward' to it."""

    def __init__(self, sub_layer):
        super().__init__()
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        # Every positional and keyword argument is passed through untouched.
        wrapped = self.sub_layer
        return wrapped(*inputs, **kwargs)
2021-08-08 14:57:29 +08:00
2021-10-15 18:25:50 +08:00
def wrap_theseus(sub_layer):
    """Return a new WrapLayer holding 'sub_layer'."""
    wrapped = WrapLayer(sub_layer)
    return wrapped
def unwrap_theseus(sub_layer):
    """Return the layer held by a WrapLayer; anything else is returned unchanged."""
    if not isinstance(sub_layer, WrapLayer):
        return sub_layer
    return sub_layer.sub_layer
2021-12-21 21:51:45 +08:00
def set_identity(parent_layer: nn.Layer,
                 layer_name: str,
                 layer_index: str = None) -> bool:
    """Replace every sub-layer that comes after the target layer with Identity.

    The sub-layers of 'parent_layer' are visited in the iteration order of its
    '_sub_layers' mapping; each one after 'layer_name' is replaced. When
    'layer_index' is given, the same is then done inside the sub-layer named
    'layer_name' for every entry after 'layer_index'.

    Args:
        parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
        layer_name (str): The name of target layer to be set to Identity.
        layer_index (str, optional): The index of target layer to be set to Identity in parent_layer. Defaults to None.

    Returns:
        bool: True if the target (name and, when given, index) was found, False otherwise.
    """
    children = parent_layer._sub_layers

    found = False
    for child_name in children:
        if found:
            children[child_name] = Identity()
        elif child_name == layer_name:
            found = True

    # Without an index, or when the name was never seen, we are done.
    if not (layer_index and found):
        return found

    container = children[layer_name]
    found = False
    for child_index in container._sub_layers:
        if found:
            container[child_index] = Identity()
        elif layer_index == child_index:
            found = True

    return found
def slice_pattern(pattern: str,
                  idx: Union[Tuple, int] = None) -> Union[List[str], None]:
    """Split the string 'pattern' into a list of segments by the separator ".".

    Whitespace around each segment is dropped, so "a.b" and "a . b" both give
    ["a", "b"].

    Args:
        pattern (str): The pattern to describe layer name.
        idx (Union[Tuple, int], optional): Which part of the segment list to
            return. An int selects a single segment, a 1-tuple does the same,
            a 2-tuple '(start, stop)' selects the slice '[start:stop]'.
            'None' (or an empty tuple) selects everything. Defaults to None.

    Returns:
        Union[List[str], None]: The selected segments, always as a list.
            'None' if 'idx' has an unsupported type or tuple length.

    Raises:
        IndexError: If a single-segment index is out of range.
    """
    pattern_list = [segment.strip() for segment in pattern.split(".")]

    # Compare against None explicitly: 0 is a valid index and must not be
    # mistaken for "no index given".
    if idx is None:
        return pattern_list

    if isinstance(idx, tuple):
        if not idx:
            return pattern_list
        if len(idx) == 1:
            # Wrap in a list so that callers always receive a list.
            return [pattern_list[idx[0]]]
        if len(idx) == 2:
            return pattern_list[idx[0]:idx[1]]
        msg = " Only support length of ' idx ' is 1 or 2 when ' idx ' is a Tuple. "
        logger.warning(msg)
        return None

    if isinstance(idx, int):
        return [pattern_list[idx]]

    msg = " Only support type of ' idx ' is int or Tuple. "
    logger.warning(msg)
    return None
2021-12-21 21:51:45 +08:00
def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
                      idx=None
                      ) -> Iterator[Dict[str, Union[nn.Layer, None, str]]]:
    """Resolve the string 'pattern' segment by segment, starting at 'parent_layer'.

    This is a generator. It yields one dict per successfully resolved segment
    and then descends into the layer it just found. At the first segment that
    cannot be resolved it logs a warning and stops, so callers can detect a
    failure by receiving fewer dicts than there are segments.

    Args:
        pattern (str): The pattern to describe layer name, e.g. "blocks[11].depthwise_conv.conv".
        parent_layer (nn.Layer): The layer in which the first segment is looked up.
        idx (Union[Tuple, int], optional): Passed to 'slice_pattern' to select which segments of 'pattern' to resolve. Defaults to None.

    Yields:
        Dict[str, Union[nn.Layer, None, str]]: A dict with the resolved layer
            (unwrapped by 'unwrap_theseus'), its attribute name, its index as
            a string ('None' when the segment has no "[...]" part) and the
            layer it was looked up in.
    """
    pattern_list = slice_pattern(pattern, idx)
    if not pattern_list:
        return

    for segment in pattern_list:
        if '[' in segment:
            name_part, _, index_part = segment.partition('[')
            target_layer_name = name_part.strip()
            target_layer_index = index_part.split(']')[0].strip()
        else:
            target_layer_name = segment.strip()
            target_layer_index = None

        target_layer = getattr(parent_layer, target_layer_name, None)
        target_layer = unwrap_theseus(target_layer)

        if target_layer is None:
            msg = f" Not found layer named( { target_layer_name } ) specifed in pattern( { pattern } ). "
            logger.warning(msg)
            return

        if target_layer_index is not None:
            try:
                # The length that matters is that of the container being
                # indexed ('target_layer'), not of the layer that holds it.
                in_range = 0 <= int(target_layer_index) < len(target_layer)
            except (TypeError, ValueError):
                # Non-integer index, or the layer does not support len().
                in_range = False
            if not in_range:
                msg = f" Not found layer by index( { target_layer_index } ) specifed in pattern( { pattern } ). The layer( { target_layer_name } ) has no sub-layer at that index. "
                logger.warning(msg)
                return

            # The index is kept as a string, as callers compare it with the
            # string keys of '_sub_layers'.
            target_layer = target_layer[target_layer_index]
            target_layer = unwrap_theseus(target_layer)

        yield {
            " target_layer ": target_layer,
            " target_layer_name ": target_layer_name,
            " target_layer_index ": target_layer_index,
            " parent_layer ": parent_layer
        }

        # The next segment is looked up inside the layer just resolved.
        parent_layer = target_layer