diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py
index 35eac5f08..64bfed0e9 100644
--- a/ppcls/arch/backbone/base/theseus_layer.py
+++ b/ppcls/arch/backbone/base/theseus_layer.py
@@ -15,6 +15,7 @@ class TheseusLayer(nn.Layer):
     def __init__(self, *args, **kwargs):
         super(TheseusLayer, self).__init__()
         self.res_dict = {}
+        self.res_name = self.full_name()
 
     # stop doesn't work when stop layer has a parallel branch.
     def stop_after(self, stop_layer_name: str):
@@ -33,29 +34,45 @@ class TheseusLayer(nn.Layer):
         return after_stop
 
     def update_res(self, return_patterns):
-        if not return_patterns or isinstance(self, WrapLayer):
-            return
-        for layer_i in self._sub_layers:
-            layer_name = self._sub_layers[layer_i].full_name()
-            if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
-                self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
-                self._sub_layers[layer_i].update_res(return_patterns)
+        for return_pattern in return_patterns:
+            pattern_list = return_pattern.split(".")
+            if not pattern_list:
+                continue
+            sub_layer_parent = self
+            while len(pattern_list) > 1:
+                if '[' in pattern_list[0]:
+                    sub_layer_name = pattern_list[0].split('[')[0]
+                    sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
+                    sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index]
+                else:
+                    sub_layer_parent = getattr(sub_layer_parent, pattern_list[0],
+                                               None)
+                if sub_layer_parent is None:
+                    break
+                if isinstance(sub_layer_parent, WrapLayer):
+                    sub_layer_parent = sub_layer_parent.sub_layer
+                pattern_list = pattern_list[1:]
+            if sub_layer_parent is None:
+                continue
+            if '[' in pattern_list[0]:
+                sub_layer_name = pattern_list[0].split('[')[0]
+                sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
+                sub_layer = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index]
+                if not isinstance(sub_layer, TheseusLayer):
+                    sub_layer = wrap_theseus(sub_layer)
+                getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] = sub_layer
             else:
-                for return_pattern in return_patterns:
-                    if re.match(return_pattern, layer_name):
-                        if not isinstance(self._sub_layers[layer_i], TheseusLayer):
-                            self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
-                        else:
-                            self._sub_layers[layer_i].res_dict = self.res_dict
+                sub_layer = getattr(sub_layer_parent, pattern_list[0])
+                if not isinstance(sub_layer, TheseusLayer):
+                    sub_layer = wrap_theseus(sub_layer)
+                setattr(sub_layer_parent, pattern_list[0], sub_layer)
 
-                        self._sub_layers[layer_i].register_forward_post_hook(
-                            self._sub_layers[layer_i]._save_sub_res_hook)
-            if isinstance(self._sub_layers[layer_i], TheseusLayer):
-                self._sub_layers[layer_i].res_dict = self.res_dict
-                self._sub_layers[layer_i].update_res(return_patterns)
+            sub_layer.res_dict = self.res_dict
+            sub_layer.res_name = return_pattern
+            sub_layer.register_forward_post_hook(sub_layer._save_sub_res_hook)
 
     def _save_sub_res_hook(self, layer, input, output):
-        self.res_dict[layer.full_name()] = output
+        self.res_dict[self.res_name] = output
 
     def _return_dict_hook(self, layer, input, output):
         res_dict = {"output": output}
@@ -63,19 +80,23 @@ class TheseusLayer(nn.Layer):
             res_dict[res_key] = self.res_dict.pop(res_key)
         return res_dict
 
-    def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
+    def replace_sub(self, layer_name_pattern, replace_function,
+                    recursive=True):
         for layer_i in self._sub_layers:
             layer_name = self._sub_layers[layer_i].full_name()
             if re.match(layer_name_pattern, layer_name):
-                self._sub_layers[layer_i] = replace_function(self._sub_layers[layer_i])
+                self._sub_layers[layer_i] = replace_function(self._sub_layers[
+                    layer_i])
                 if recursive:
                     if isinstance(self._sub_layers[layer_i], TheseusLayer):
                         self._sub_layers[layer_i].replace_sub(
                             layer_name_pattern, replace_function, recursive)
-                    elif isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
+                    elif isinstance(self._sub_layers[layer_i],
+                                    (nn.Sequential, nn.LayerList)):
                         for layer_j in self._sub_layers[layer_i]._sub_layers:
-                            self._sub_layers[layer_i]._sub_layers[layer_j].replace_sub(
-                                layer_name_pattern, replace_function, recursive)
+                            self._sub_layers[layer_i]._sub_layers[
+                                layer_j].replace_sub(layer_name_pattern,
+                                                     replace_function, recursive)
 
     '''
     example of replace function:
@@ -92,39 +113,14 @@ class TheseusLayer(nn.Layer):
 
 
 class WrapLayer(TheseusLayer):
-    def __init__(self, sub_layer, res_dict=None):
+    def __init__(self, sub_layer):
         super(WrapLayer, self).__init__()
         self.sub_layer = sub_layer
-        self.name = sub_layer.full_name()
-        if res_dict is not None:
-            self.res_dict = res_dict
-
-    def full_name(self):
-        return self.name
-
 
     def forward(self, *inputs, **kwargs):
         return self.sub_layer(*inputs, **kwargs)
 
-    def update_res(self, return_patterns):
-        if not return_patterns or not isinstance(self.sub_layer, (nn.Sequential, nn.LayerList)):
-            return
-        for layer_i in self.sub_layer._sub_layers:
-            if isinstance(self.sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
-                self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i], self.res_dict)
-                self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
-            elif isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
-                self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
-                layer_name = self.sub_layer._sub_layers[layer_i].full_name()
-                for return_pattern in return_patterns:
-                    if re.match(return_pattern, layer_name):
-                        self.sub_layer._sub_layers[layer_i].register_forward_post_hook(
-                            self._sub_layers[layer_i]._save_sub_res_hook)
-
-            if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
-                self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
-
-
-def wrap_theseus(sub_layer, res_dict=None):
-    wrapped_layer = WrapLayer(sub_layer, res_dict)
+
+def wrap_theseus(sub_layer):
+    wrapped_layer = WrapLayer(sub_layer)
     return wrapped_layer
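Note on the `update_res` rewrite: instead of regex-matching every sublayer's `full_name()`, each return pattern is now resolved as a dot-separated attribute path with optional `[index]` access (e.g. `"blocks[0].conv"`). The resolved layer is wrapped in a `WrapLayer` when it is not already a `TheseusLayer`, and its output is saved into the shared `res_dict` under the pattern string via `res_name`. The sketch below is a hypothetical usage example, not code from this diff; `ToyNet` is an invented toy model, assuming `ppcls` is importable:

```python
import paddle
import paddle.nn as nn
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer


class ToyNet(TheseusLayer):
    # Invented toy model, only to illustrate the pattern syntax.
    def __init__(self):
        super().__init__()
        self.blocks = nn.LayerList(
            [nn.Conv2D(3, 3, 3, padding=1) for _ in range(2)])
        self.head = nn.Linear(3, 10)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = x.mean(axis=[2, 3])  # global average pool
        return self.head(x)


net = ToyNet()
# Hook the first block by attribute path instead of a full_name() regex.
net.update_res(["blocks[0]"])
# Make forward return {"output": ..., "blocks[0]": ...}.
net.register_forward_post_hook(net._return_dict_hook)

res = net(paddle.rand([1, 3, 8, 8]))
print(res["output"].shape)     # [1, 10]
print(res["blocks[0]"].shape)  # [1, 3, 8, 8]
```

Keying `res_dict` by the user-supplied pattern rather than by `full_name()` is also what lets `WrapLayer` drop its `full_name` override and the `res_dict` constructor plumbing.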
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index 347ff3130..4b026a828 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -37,22 +37,22 @@ def train_epoch(engine, epoch_id, print_batch_step):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
         engine.global_step += 1
+        if engine.config["DataLoader"]["Train"]["dataset"].get(
+                "batch_transform_ops", None):
+            gt_input = batch[1:]
+        else:
+            gt_input = batch[1]
+
         # image input
         if engine.amp:
             with paddle.amp.auto_cast(custom_black_list={
                     "flatten_contiguous_range", "greater_than"
             }):
                 out = forward(engine, batch)
-                loss_dict = engine.train_loss_func(out, batch[1])
+                loss_dict = engine.train_loss_func(out, gt_input)
         else:
             out = forward(engine, batch)
-
-            # calc loss
-            if engine.config["DataLoader"]["Train"]["dataset"].get(
-                    "batch_transform_ops", None):
-                loss_dict = engine.train_loss_func(out, batch[1:])
-            else:
-                loss_dict = engine.train_loss_func(out, batch[1])
+            loss_dict = engine.train_loss_func(out, gt_input)
 
         # step opt and lr
         if engine.amp:
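Note on the train-loop change: the `batch_transform_ops` check is hoisted above the AMP branch, so the AMP path no longer hardcodes `batch[1]` and both paths feed the same `gt_input` to the loss. A standalone sketch of the shared selection logic; the helper name and dummy values are illustrative, not from the PR:

```python
# Hypothetical helper mirroring the hoisted gt_input selection; in the PR the
# flag comes from engine.config["DataLoader"]["Train"]["dataset"].
def select_gt_input(batch, uses_batch_transform_ops):
    # Mixup/CutMix-style batch transforms append extra targets after the
    # image (e.g. y_a, y_b, lam), so the loss needs everything past batch[0].
    return batch[1:] if uses_batch_transform_ops else batch[1]


batch = ["images", "y_a", "y_b", 0.7]
assert select_gt_input(batch, True) == ["y_a", "y_b", 0.7]
assert select_gt_input(batch, False) == "y_a"
```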