From 877c8c53be608eec952483883c72c80402843d3a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 24 Aug 2022 07:25:30 +0000 Subject: [PATCH] correct config yaml, pp_lcnetv2_variant.py and log in&out channels in pp_lcnet_v2.py --- ppcls/arch/backbone/base/theseus_layer.py | 61 +++++++------ .../backbone/legendary_models/pp_lcnet_v2.py | 2 + .../variant_models/pp_lcnetv2_variant.py | 62 +++++++------ .../GeneralRecognitionV2_PPLCNetV2_base.yaml | 16 ++-- ppcls/data/preprocess/ops/test_pad.py | 90 ------------------- 5 files changed, 81 insertions(+), 150 deletions(-) delete mode 100644 ppcls/data/preprocess/ops/test_pad.py diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 6a4d6c0af..192077c11 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -103,7 +103,7 @@ class TheseusLayer(nn.Layer): return new_layer net = paddleclas.MobileNetV1() - res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func) + res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func) print(res) # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer} """ @@ -122,13 +122,21 @@ class TheseusLayer(nn.Layer): sub_layer = layer_list[-1]["layer"] sub_layer_name = layer_list[-1]["name"] - sub_layer_index = layer_list[-1]["index"] + sub_layer_index_list = layer_list[-1]["index_list"] new_sub_layer = handle_func(sub_layer, pattern) - if sub_layer_index: - getattr(sub_layer_parent, - sub_layer_name)[sub_layer_index] = new_sub_layer + if sub_layer_index_list: + if len(sub_layer_index_list) > 1: + sub_layer_parent = getattr( + sub_layer_parent, + sub_layer_name)[sub_layer_index_list[0]] + for sub_layer_index in sub_layer_index_list[1:-1]: + sub_layer_parent = sub_layer_parent[sub_layer_index] + sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer + else: + getattr(sub_layer_parent, sub_layer_name)[ + sub_layer_index_list[0]] = new_sub_layer else: setattr(sub_layer_parent, sub_layer_name, new_sub_layer) @@ -151,15 +159,13 @@ class TheseusLayer(nn.Layer): parent_layer = self for layer_dict in layer_list: - name, index = layer_dict["name"], layer_dict["index"] - if not set_identity(parent_layer, name, index): + name, index_list = layer_dict["name"], layer_dict["index_list"] + if not set_identity(parent_layer, name, index_list): msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'." logger.warning(msg) return False parent_layer = layer_dict["layer"] - msg = f"Successfully set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer." - logger.info(msg) return True def update_res( @@ -208,15 +214,14 @@ def save_sub_res_hook(layer, input, output): layer.res_dict[layer.res_name] = output -def set_identity(parent_layer: nn.Layer, - layer_name: str, - layer_index: str=None) -> bool: - """set the layer specified by layer_name and layer_index to Indentity. +def set_identity(parent_layer: nn.Layer, layer_name: str, + index_list: str=None) -> bool: + """set the layer specified by layer_name and index_list to Indentity. Args: - parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index. 
+        parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and index_list.
         layer_name (str): The name of target layer to be set to Indentity.
-        layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
+        index_list (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
 
     Returns:
         bool: True if successfully, False otherwise.
@@ -230,7 +235,7 @@ def set_identity(parent_layer: nn.Layer,
         if sub_layer_name == layer_name:
             stop_after = True
 
-        if layer_index and stop_after:
+        if index_list and stop_after:
             stop_after = False
             for sub_layer_index in parent_layer._sub_layers[
                     layer_name]._sub_layers:
@@ -271,10 +276,12 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
     while len(pattern_list) > 0:
         if '[' in pattern_list[0]:
             target_layer_name = pattern_list[0].split('[')[0]
-            target_layer_index = pattern_list[0].split('[')[1].split(']')[0]
+            target_layer_index_list = list(
+                index.split(']')[0]
+                for index in pattern_list[0].split('[')[1:])
         else:
             target_layer_name = pattern_list[0]
-            target_layer_index = None
+            target_layer_index_list = None
 
         target_layer = getattr(parent_layer, target_layer_name, None)
 
@@ -283,19 +290,19 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
             logger.warning(msg)
             return None
 
-        if target_layer_index and target_layer:
-            if int(target_layer_index) < 0 or int(target_layer_index) >= len(
-                    target_layer):
-                msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
-                logger.warning(msg)
-                return None
-
-            target_layer = target_layer[target_layer_index]
+        if target_layer_index_list:
+            for target_layer_index in target_layer_index_list:
+                if int(target_layer_index) < 0 or int(
+                        target_layer_index) >= len(target_layer):
+                    msg = f"No layer found at index('{target_layer_index}') specified in pattern('{pattern}'). The index should be >= 0 and < {len(target_layer)}."
+                    logger.warning(msg)
+                    return None
+                target_layer = target_layer[target_layer_index]
 
         layer_list.append({
             "layer": target_layer,
             "name": target_layer_name,
-            "index": target_layer_index
+            "index_list": target_layer_index_list
         })
 
         pattern_list = pattern_list[1:]
diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
index b48d33e05..ea24489c1 100644
--- a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
+++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
@@ -126,6 +126,8 @@ class RepDepthwiseSeparable(TheseusLayer):
                  use_se=False,
                  use_shortcut=False):
         super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
         self.is_repped = False
 
         self.dw_size = dw_size
diff --git a/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py b/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
index 73c76993d..6acccdc8e 100644
--- a/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
+++ b/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py
@@ -1,43 +1,55 @@
 from paddle.nn import Conv2D, Identity
-from ..legendary_models.pp_lcnet_v2 import PPLCNetV2_base, RepDepthwiseSeparable, MODEL_URLS, _load_pretrained
+
+from ..legendary_models.pp_lcnet_v2 import MODEL_URLS, PPLCNetV2_base, RepDepthwiseSeparable, _load_pretrained
 
 __all__ = ["PPLCNetV2_base_ShiTu"]
 
 
 def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs):
-
+    """
+    A variant network of PPLCNetV2_base
+    1. remove the ReLU layer after last_conv
+    2. add bias to last_conv
+    3. change the stride to 1 in the last two RepDepthwiseSeparable blocks
+    """
     model = PPLCNetV2_base(pretrained=False, use_ssld=use_ssld, **kwargs)
 
     def remove_ReLU_function(conv, pattern):
         new_conv = Identity()
         return new_conv
 
-    # def last_stride_function(conv, pattern):
-    #     new_conv = Conv2D(
-    #         weight_attr=conv._param_attr,
-    #         in_channels=conv._in_channels,
-    #         out_channels=conv._out_channels,
-    #         kernel_size=conv._kernel_size,
-    #         stride=1,
-    #         padding=conv._padding,
-    #         groups=conv._groups,
-    #         bias_attr=conv._bias_attr)
-    #     return new_conv
+    def add_bias_last_conv(conv, pattern):
+        new_conv = Conv2D(
+            in_channels=conv._in_channels,
+            out_channels=conv._out_channels,
+            kernel_size=conv._kernel_size,
+            stride=conv._stride,
+            padding=conv._padding,
+            groups=conv._groups,
+            bias_attr=True)
+        return new_conv
+
+    def last_stride_function(rep_block, pattern):
+        new_conv = RepDepthwiseSeparable(
+            in_channels=rep_block.in_channels,
+            out_channels=rep_block.out_channels,
+            stride=1,
+            dw_size=rep_block.dw_size,
+            split_pw=rep_block.split_pw,
+            use_rep=rep_block.use_rep,
+            use_se=rep_block.use_se,
+            use_shortcut=rep_block.use_shortcut)
+        return new_conv
 
     pattern_act = ["act"]
-    # pattern_last_stride = [
-    #     "stages[3][0].dw_conv_list[0].conv",
-    #     "stages[3][0].dw_conv_list[1].conv",
-    #     "stages[3][0].dw_conv",
-    #     "stages[3][0].pw_conv.conv",
-    #     "stages[3][1].dw_conv_list[0].conv",
-    #     "stages[3][1].dw_conv_list[1].conv",
-    #     "stages[3][1].dw_conv_list[2].conv",
-    #     "stages[3][1].dw_conv",
-    #     "stages[3][1].pw_conv.conv",
-    # ]
-    # model.upgrade_sublayer(pattern_last_stride, last_stride_function)  # TODO: theseuslayer has a BUG, temporarily commented out
+    pattern_lastconv = ["last_conv"]
+    pattern_last_stride = [
+        "stages[3][0]",
+        "stages[3][1]",
+    ]
     model.upgrade_sublayer(pattern_act, remove_ReLU_function)
+    model.upgrade_sublayer(pattern_lastconv, add_bias_last_conv)
+    model.upgrade_sublayer(pattern_last_stride, last_stride_function)
 
     # load params again after upgrade some layers
     _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
diff --git a/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml b/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
index 7e2647cec..e6dfde7cd 100644
--- a/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
+++ b/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml
@@ -18,11 +18,11 @@ Global:
   image_shape: [3, 224, 224]
   save_inference_dir: ./inference
 
-# AMP:
-#   scale_loss: 65536
-#   use_dynamic_loss_scaling: True
-#   # O1: mixed fp16
-#   level: O1
+AMP:
+  scale_loss: 65536
+  use_dynamic_loss_scaling: True
+  # O1: mixed fp16
+  level: O1
 
 # model architecture
 Arch:
@@ -96,7 +96,7 @@ DataLoader:
     dataset:
       name: ImageNetDataset
      image_root: ./dataset/
-      cls_label_path: ./dataset/train_reg_all_data.txt
+      cls_label_path: ./dataset/train_reg_all_data_v2.txt
      relabel: True
      transform_ops:
        - DecodeImage:
@@ -130,12 +130,12 @@
            order: hwc
    sampler:
      name: PKSampler
-      batch_size: 8
+      batch_size: 256
      sample_per_id: 4
      drop_last: False
      shuffle: True
      sample_method: "id_avg_prob"
-      id_list: [50030, 80700, 92019, 96015]
+      id_list: [50030, 80700, 92019, 96015] # be careful when setting relabel=True
      ratio: [4, 4]
    loader:
      num_workers: 4
diff --git a/ppcls/data/preprocess/ops/test_pad.py b/ppcls/data/preprocess/ops/test_pad.py
deleted file mode 100644
index 56ecb20a4..000000000
--- a/ppcls/data/preprocess/ops/test_pad.py
+++ /dev/null
@@
-1,90 +0,0 @@ -import numpy as np - -import paddle.vision.transforms as T -import cv2 - - -class Pad(object): - """ - Pads the given PIL.Image on all sides with specified padding mode and fill value. - adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad - """ - - def __init__(self, - padding: int, - fill: int=0, - padding_mode: str="constant", - backend: str="pil"): - self.padding = padding - self.fill = fill - self.padding_mode = padding_mode - self.backend = backend - assert backend in [ - "pil", "cv2" - ], f"backend in Pad must in ['pil', 'cv2'], but got {backend}" - - def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"): - # Process fill color for affine transforms - major_found, minor_found = (int(v) - for v in PILLOW_VERSION.split('.')[:2]) - major_required, minor_required = (int(v) for v in - min_pil_version.split('.')[:2]) - if major_found < major_required or (major_found == major_required and - minor_found < minor_required): - if fill is None: - return {} - else: - msg = ( - "The option to fill background area of the transformed image, " - "requires pillow>={}") - raise RuntimeError(msg.format(min_pil_version)) - - num_bands = len(img.getbands()) - if fill is None: - fill = 0 - if isinstance(fill, (int, float)) and num_bands > 1: - fill = tuple([fill] * num_bands) - if isinstance(fill, (list, tuple)): - if len(fill) != num_bands: - msg = ( - "The number of elements in 'fill' does not match the number of " - "bands of the image ({} != {})") - raise ValueError(msg.format(len(fill), num_bands)) - - fill = tuple(fill) - - return {name: fill} - - def __call__(self, img): - if self.backend == "pil": - opts = self._parse_fill(self.fill, img, "2.3.0", name="fill") - if img.mode == "P": - palette = img.getpalette() - img = ImageOps.expand(img, border=self.padding, **opts) - img.putpalette(palette) - return img - return ImageOps.expand(img, border=self.padding, **opts) - else: - img = cv2.copyMakeBorder( - img, - self.padding, - self.padding, - self.padding, - self.padding, - cv2.BORDER_CONSTANT, - value=(self.fill, self.fill, self.fill)) - return img - - -img = np.random.randint(0, 255, [3, 4, 3], dtype=np.uint8) - -for p in range(0, 10): - for v in range(0, 10): - img_1 = Pad(p, v, backend="cv2")(img) - img_2 = T.Pad(p, (v, v, v))(img) - print(f"{p} - {v}", np.allclose(img_1, img_2)) - if not np.allclose(img_1, img_2): - print(img_1[..., 0], "\n", img_2[..., 0]) - print(img_1[..., 1], "\n", img_2[..., 1]) - print(img_1[..., 2], "\n", img_2[..., 2]) - exit(0)
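For reference, the behavioral core of this patch is that parse_pattern_str / upgrade_sublayer in theseus_layer.py now accept patterns with more than one index, such as "stages[3][0]": each pattern segment is split into a layer name plus a list of indices, and the nested containers are walked index by index before the final element is replaced. The sketch below is a minimal, framework-free illustration of that idea, under the assumption that plain Python attributes and lists stand in for paddle.nn.Layer containers; the names toy_parse_segment, toy_replace and ToyModel are illustrative only and are not part of PaddleClas.

# Minimal sketch of the multi-index pattern handling added in this patch.
# Assumption: plain Python attributes/lists stand in for paddle.nn.Layer containers.

def toy_parse_segment(segment):
    """Split 'stages[3][0]' into ('stages', ['3', '0']); 'act' -> ('act', None)."""
    if "[" in segment:
        name = segment.split("[")[0]
        index_list = [part.split("]")[0] for part in segment.split("[")[1:]]
        return name, index_list
    return segment, None


def toy_replace(parent, segment, new_value):
    """Walk the index list on getattr(parent, name) and replace the last element,
    mirroring the nested-index branch of TheseusLayer.upgrade_sublayer."""
    name, index_list = toy_parse_segment(segment)
    if not index_list:
        setattr(parent, name, new_value)
        return
    container = getattr(parent, name)
    for index in index_list[:-1]:
        container = container[int(index)]
    container[int(index_list[-1])] = new_value


class ToyModel:
    def __init__(self):
        # "stages" is a list of lists, standing in for nested LayerLists of blocks.
        self.stages = [["b00", "b01"], ["b10", "b11"]]
        self.act = "relu"


if __name__ == "__main__":
    model = ToyModel()
    toy_replace(model, "stages[1][0]", "new_block")  # nested index, like "stages[3][0]"
    toy_replace(model, "act", "identity")            # plain attribute, like pattern_act
    print(model.stages)  # [['b00', 'b01'], ['new_block', 'b11']]
    print(model.act)     # identity

The same mechanism is what lets PPLCNetV2_base_ShiTu target whole blocks ("stages[3][0]", "stages[3][1]") with last_stride_function instead of enumerating every inner conv layer, as the previously commented-out code did.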