refactor: extract _parse_pattern_str() func
parent
18dec0744a
commit
8d0b0d4b0a
|
@ -66,53 +66,22 @@ class TheseusLayer(nn.Layer):
|
|||
|
||||
handle_res_dict = {}
|
||||
for pattern in layer_name_pattern:
|
||||
pattern_list = pattern.split(".")
|
||||
# pattern_list = pattern.split(".")
|
||||
|
||||
# find parent layer of sub-layer specified by pattern
|
||||
sub_layer_parent = self
|
||||
while len(pattern_list) > 1:
|
||||
if '[' in pattern_list[0]:
|
||||
sub_layer_name = pattern_list[0].split('[')[0]
|
||||
sub_layer_index = pattern_list[0].split('[')[1].split(']')[
|
||||
0]
|
||||
sub_layer_parent = getattr(sub_layer_parent,
|
||||
sub_layer_name)[sub_layer_index]
|
||||
else:
|
||||
sub_layer_parent = getattr(sub_layer_parent,
|
||||
pattern_list[0], None)
|
||||
if sub_layer_parent is None:
|
||||
break
|
||||
if isinstance(sub_layer_parent, WrapLayer):
|
||||
sub_layer_parent = sub_layer_parent.sub_layer
|
||||
pattern_list = pattern_list[1:]
|
||||
if sub_layer_parent is None:
|
||||
msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
sub_layer_parent, _, _ = parse_pattern_str(
|
||||
pattern=pattern, idx=(0, -1), sub_layer_parent=self)
|
||||
|
||||
if not sub_layer_parent:
|
||||
continue
|
||||
|
||||
# find sub-layer specified by pattern
|
||||
if '[' in pattern_list[0]:
|
||||
sub_layer_name = pattern_list[0].split('[')[0]
|
||||
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
|
||||
else:
|
||||
sub_layer_name = pattern_list[0]
|
||||
sub_layer_index = None
|
||||
|
||||
sub_layer = getattr(sub_layer_parent, sub_layer_name, None)
|
||||
sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str(
|
||||
pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent)
|
||||
|
||||
if not sub_layer:
|
||||
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
continue
|
||||
|
||||
if sub_layer_index is not None:
|
||||
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
|
||||
sub_layer):
|
||||
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
continue
|
||||
sub_layer = sub_layer[sub_layer_index]
|
||||
|
||||
new_sub_layer = handle_func(sub_layer, pattern)
|
||||
|
||||
if sub_layer_index:
|
||||
|
@ -156,6 +125,7 @@ class TheseusLayer(nn.Layer):
|
|||
pattern_list = stop_layer_name.split(".")
|
||||
to_identity_list = []
|
||||
|
||||
# TODO(gaotingquan): replace code by self._parse_pattern_str()
|
||||
layer = self
|
||||
while len(pattern_list) > 0:
|
||||
layer_parent = layer
|
||||
|
@ -219,5 +189,67 @@ class WrapLayer(TheseusLayer):
|
|||
|
||||
|
||||
def wrap_theseus(sub_layer):
|
||||
wrapped_layer = WrapLayer(sub_layer)
|
||||
return wrapped_layer
|
||||
return WrapLayer(sub_layer)
|
||||
|
||||
|
||||
def unwrap_theseus(sub_layer):
|
||||
if isinstance(sub_layer, WrapLayer):
|
||||
sub_layer = sub_layer.sub_layer
|
||||
return sub_layer
|
||||
|
||||
|
||||
def slice_pattern(pattern, idx):
|
||||
pattern_list = pattern.split(".")
|
||||
if idx:
|
||||
if isinstance(idx, tuple):
|
||||
if len(idx) == 1:
|
||||
return pattern_list[idx[0]]
|
||||
elif len(idx) == 2:
|
||||
return pattern_list[idx[0]:idx[1]]
|
||||
else:
|
||||
msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a tuple."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
elif isinstance(idx, int):
|
||||
return [pattern_list[idx]]
|
||||
else:
|
||||
msg = f"Only support type of 'idx' is int or tuple."
|
||||
logger.warning(msg)
|
||||
return None
|
||||
|
||||
return pattern_list
|
||||
|
||||
|
||||
def parse_pattern_str(pattern, sub_layer_parent, idx=None):
|
||||
pattern_list = slice_pattern(pattern, idx)
|
||||
if not pattern_list:
|
||||
return None, None, None
|
||||
|
||||
while len(pattern_list) > 0:
|
||||
if '[' in pattern_list[0]:
|
||||
sub_layer_name = pattern_list[0].split('[')[0]
|
||||
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
|
||||
else:
|
||||
sub_layer_name = pattern_list[0]
|
||||
sub_layer_index = None
|
||||
|
||||
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None)
|
||||
sub_layer_parent = unwrap_theseus(sub_layer_parent)
|
||||
|
||||
if sub_layer_parent is None:
|
||||
msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})."
|
||||
logger.warning(msg)
|
||||
return None, sub_layer_name, sub_layer_index
|
||||
|
||||
if sub_layer_index and sub_layer_parent:
|
||||
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
|
||||
sub_layer_parent):
|
||||
msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'."
|
||||
logger.warning(msg)
|
||||
return None, sub_layer_name, sub_layer_index
|
||||
sub_layer_parent = sub_layer_parent[sub_layer_index]
|
||||
sub_layer_parent = unwrap_theseus(sub_layer_parent)
|
||||
|
||||
pattern_list = pattern_list[1:]
|
||||
|
||||
return sub_layer_parent, sub_layer_name, sub_layer_index
|
||||
|
|
Loading…
Reference in New Issue