diff --git a/deploy/configs/inference_product_binary.yaml b/deploy/configs/inference_general_binary.yaml similarity index 64% rename from deploy/configs/inference_product_binary.yaml rename to deploy/configs/inference_general_binary.yaml index aefe0ba6a..d76dae8f8 100644 --- a/deploy/configs/inference_product_binary.yaml +++ b/deploy/configs/inference_general_binary.yaml @@ -1,7 +1,7 @@ Global: - infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg" - det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer" - rec_inference_model_dir: "./models/product_MV3_x1_0_aliproduct_bin_v1.0_infer" + infer_imgs: "./drink_dataset_v1.0/test_images/001.jpeg" + det_inference_model_dir: "./models/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer" + rec_inference_model_dir: "./models/general_PPLCNet_x2_5_lite_binary_v1.0_infer" rec_nms_thresold: 0.05 batch_size: 1 @@ -11,7 +11,6 @@ Global: labe_list: - foreground - # inference engine config use_gpu: True enable_mkldnn: True cpu_num_threads: 10 @@ -49,19 +48,18 @@ RecPreProcess: RecPostProcess: main_indicator: Binarize Binarize: - method: "round" + method: "sign" # indexing engine config IndexProcess: - index_method: "Flat" # supported: HNSW32, Flat - index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary" - image_root: "./recognition_demo_data_v1.1/gallery_product/" - data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt" + index_method: "Flat" # supported: HNSW32, Flat + image_root: "./drink_dataset_v1.0/gallery/" + index_dir: "./drink_dataset_v1.0/index_bin" + data_file: "./drink_dataset_v1.0/gallery/drink_label.txt" index_operation: "new" # suported: "append", "remove", "new" delimiter: "\t" dist_type: "hamming" embedding_size: 512 batch_size: 32 - binary_index: true return_k: 5 - score_thres: 0 \ No newline at end of file + hamming_radius: 100 diff --git a/deploy/python/predict_system.py b/deploy/python/predict_system.py index a93d5f06a..239875535 100644 --- a/deploy/python/predict_system.py +++ b/deploy/python/predict_system.py @@ -47,14 +47,14 @@ class SystemPredictor(object): index_dir, "vector.index")), "vector.index not found ..." assert os.path.exists(os.path.join( index_dir, "id_map.pkl")), "id_map.pkl not found ... " - - if config['IndexProcess'].get("binary_index", False): + + if config['IndexProcess'].get("dist_type") == "hamming": self.Searcher = faiss.read_index_binary( os.path.join(index_dir, "vector.index")) else: self.Searcher = faiss.read_index( os.path.join(index_dir, "vector.index")) - + with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd: self.id_map = pickle.load(fd) @@ -111,12 +111,19 @@ class SystemPredictor(object): rec_results = self.rec_predictor.predict(crop_img) preds["bbox"] = [xmin, ymin, xmax, ymax] scores, docs = self.Searcher.search(rec_results, self.return_k) - + # just top-1 result will be returned for the final - if scores[0][0] >= self.config["IndexProcess"]["score_thres"]: - preds["rec_docs"] = self.id_map[docs[0][0]].split()[1] - preds["rec_scores"] = scores[0][0] - output.append(preds) + if self.config["IndexProcess"]["dist_type"] == "hamming": + if scores[0][0] <= self.config["IndexProcess"][ + "hamming_radius"]: + preds["rec_docs"] = self.id_map[docs[0][0]].split()[1] + preds["rec_scores"] = scores[0][0] + output.append(preds) + else: + if scores[0][0] >= self.config["IndexProcess"]["score_thres"]: + preds["rec_docs"] = self.id_map[docs[0][0]].split()[1] + preds["rec_scores"] = scores[0][0] + output.append(preds) # st5: nms to the final results to avoid fetching duplicate results output = self.nms_to_rec_results( diff --git a/docs/images/wx_group.png b/docs/images/wx_group.png index b4d253f4d..4a137c802 100644 Binary files a/docs/images/wx_group.png and b/docs/images/wx_group.png differ diff --git a/docs/zh_CN/quick_start/quick_start_recognition.md b/docs/zh_CN/quick_start/quick_start_recognition.md index 5223e00c0..e2e6b169e 100644 --- a/docs/zh_CN/quick_start/quick_start_recognition.md +++ b/docs/zh_CN/quick_start/quick_start_recognition.md @@ -40,8 +40,9 @@ | 模型简介 | 推荐场景 | inference 模型 | 预测配置文件 | | ------------ | ------------- | -------- | ------- | -| 轻量级通用主体检测模型 | 通用场景 |[tar 格式文件下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar) [zip 格式文件下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.zip) | - | +| 轻量级通用主体检测模型 | 通用场景 |[tar 格式下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar) [zip 格式文件下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.zip) | - | | 轻量级通用识别模型 | 通用场景 | [tar 格式下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar) [zip 格式文件下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.zip) | [inference_general.yaml](../../../deploy/configs/inference_general.yaml) | +| 轻量级通用识别二值模型 | 检索库很大, 存储受限场景 | [tar 格式下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_binary_v1.0_infer.tar) [zip 格式文件下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_binary_v1.0_infer.zip)| [inference_general_binary.yaml](../../../deploy/configs/inference_general_binary.yaml) | 注意:由于部分解压缩软件在解压上述 `tar` 格式文件时存在问题,建议非命令行用户下载 `zip` 格式文件并解压。`tar` 格式文件建议使用命令 `tar xf xxx.tar` 解压。 @@ -331,4 +332,3 @@ wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognit 按照上述步骤下载模型和测试数据后,您可以进行相关方向识别模型的测试。 * 更多关于主体检测的介绍可以参考:[主体检测教程文档](../image_recognition_pipeline/mainbody_detection.md);关于特征提取的介绍可以参考:[特征提取教程文档](../image_recognition_pipeline/feature_extraction.md);关于向量检索的介绍可以参考:[向量检索教程文档](../image_recognition_pipeline/vector_search.md)。 - diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index f5b815d16..b5afd823d 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -1,6 +1,20 @@ -from abc import ABC +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, List, Dict, Union, Callable, Any from paddle import nn -import re +from ppcls.utils import logger class Identity(nn.Layer): @@ -19,110 +33,239 @@ class TheseusLayer(nn.Layer): self.pruner = None self.quanter = None - # stop doesn't work when stop layer has a parallel branch. - def stop_after(self, stop_layer_name: str): - after_stop = False - for layer_i in self._sub_layers: - if after_stop: - self._sub_layers[layer_i] = Identity() - continue - layer_name = self._sub_layers[layer_i].full_name() - if layer_name == stop_layer_name: - after_stop = True - continue - if isinstance(self._sub_layers[layer_i], TheseusLayer): - after_stop = self._sub_layers[layer_i].stop_after( - stop_layer_name) - return after_stop - - def update_res(self, return_patterns): - for return_pattern in return_patterns: - pattern_list = return_pattern.split(".") - if not pattern_list: - continue - sub_layer_parent = self - while len(pattern_list) > 1: - if '[' in pattern_list[0]: - sub_layer_name = pattern_list[0].split('[')[0] - sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] - sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] - else: - sub_layer_parent = getattr(sub_layer_parent, pattern_list[0], - None) - if sub_layer_parent is None: - break - if isinstance(sub_layer_parent, WrapLayer): - sub_layer_parent = sub_layer_parent.sub_layer - pattern_list = pattern_list[1:] - if sub_layer_parent is None: - continue - if '[' in pattern_list[0]: - sub_layer_name = pattern_list[0].split('[')[0] - sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] - sub_layer = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] - if not isinstance(sub_layer, TheseusLayer): - sub_layer = wrap_theseus(sub_layer) - getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] = sub_layer - else: - sub_layer = getattr(sub_layer_parent, pattern_list[0]) - if not isinstance(sub_layer, TheseusLayer): - sub_layer = wrap_theseus(sub_layer) - setattr(sub_layer_parent, pattern_list[0], sub_layer) - - sub_layer.res_dict = self.res_dict - sub_layer.res_name = return_pattern - sub_layer.register_forward_post_hook(sub_layer._save_sub_res_hook) + def _return_dict_hook(self, layer, input, output): + res_dict = {"output": output} + # 'list' is needed to avoid error raised by popping self.res_dict + for res_key in list(self.res_dict): + # clear the res_dict because the forward process may change according to input + res_dict[res_key] = self.res_dict.pop(res_key) + return res_dict def _save_sub_res_hook(self, layer, input, output): self.res_dict[self.res_name] = output - def _return_dict_hook(self, layer, input, output): - res_dict = {"output": output} - for res_key in list(self.res_dict): - res_dict[res_key] = self.res_dict.pop(res_key) + def replace_sub(self, *args, **kwargs) -> None: + msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead." + logger.error(DeprecationWarning(msg)) + raise DeprecationWarning(msg) + + def upgrade_sublayer(self, + layer_name_pattern: Union[str, List[str]], + handle_func: Callable[[nn.Layer, str], nn.Layer] + ) -> Dict[str, nn.Layer]: + """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'. + + Args: + layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'. + handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed. + + Returns: + Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'. + + Examples: + + from paddle import nn + import paddleclas + + def rep_func(layer: nn.Layer, pattern: str): + new_layer = nn.Conv2D( + in_channels=layer._in_channels, + out_channels=layer._out_channels, + kernel_size=5, + padding=2 + ) + return new_layer + + net = paddleclas.MobileNetV1() + res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func) + print(res) + # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer} + """ + + if not isinstance(layer_name_pattern, list): + layer_name_pattern = [layer_name_pattern] + + handle_res_dict = {} + for pattern in layer_name_pattern: + # parse pattern to find target layer and its parent + layer_list = parse_pattern_str(pattern=pattern, parent_layer=self) + if not layer_list: + continue + sub_layer_parent = layer_list[-2]["layer"] if len( + layer_list) > 1 else self + + sub_layer = layer_list[-1]["layer"] + sub_layer_name = layer_list[-1]["name"] + sub_layer_index = layer_list[-1]["index"] + + new_sub_layer = handle_func(sub_layer, pattern) + + if sub_layer_index: + getattr(sub_layer_parent, + sub_layer_name)[sub_layer_index] = new_sub_layer + else: + setattr(sub_layer_parent, sub_layer_name, new_sub_layer) + + handle_res_dict[pattern] = new_sub_layer + return handle_res_dict + + def stop_after(self, stop_layer_name: str) -> bool: + """stop forward and backward after 'stop_layer_name'. + + Args: + stop_layer_name (str): The name of layer that stop forward and backward after this layer. + + Returns: + bool: 'True' if successful, 'False' otherwise. + """ + + layer_list = parse_pattern_str(stop_layer_name, self) + if not layer_list: + return False + + parent_layer = self + for layer_dict in layer_list: + name, index = layer_dict["name"], layer_dict["index"] + if not set_identity(parent_layer, name, index): + msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'." + logger.warning(msg) + return False + parent_layer = layer_dict["layer"] + + return True + + def update_res( + self, + return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]: + """update the result(s) to be returned. + + Args: + return_patterns (Union[str, List[str]]): The name of layer to return output. + + Returns: + Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully. + """ + + # clear res_dict that could have been set + self.res_dict = {} + + class Handler(object): + def __init__(self, res_dict): + # res_dict is a reference + self.res_dict = res_dict + + def __call__(self, layer, pattern): + layer.res_dict = self.res_dict + layer.res_name = pattern + if hasattr(layer, "hook_remove_helper"): + layer.hook_remove_helper.remove() + layer.hook_remove_helper = layer.register_forward_post_hook( + layer._save_sub_res_hook) + return layer + + handle_func = Handler(self.res_dict) + + res_dict = self.upgrade_sublayer( + return_patterns, handle_func=handle_func) + + if hasattr(self, "hook_remove_helper"): + self.hook_remove_helper.remove() + self.hook_remove_helper = self.register_forward_post_hook( + self._return_dict_hook) + return res_dict - def replace_sub(self, layer_name_pattern, replace_function, - recursive=True): - for layer_i in self._sub_layers: - layer_name = self._sub_layers[layer_i].full_name() - if re.match(layer_name_pattern, layer_name): - self._sub_layers[layer_i] = replace_function(self._sub_layers[ - layer_i]) - if recursive: - if isinstance(self._sub_layers[layer_i], TheseusLayer): - self._sub_layers[layer_i].replace_sub( - layer_name_pattern, replace_function, recursive) - elif isinstance(self._sub_layers[layer_i], - (nn.Sequential, nn.LayerList)): - for layer_j in self._sub_layers[layer_i]._sub_layers: - self._sub_layers[layer_i]._sub_layers[ - layer_j].replace_sub(layer_name_pattern, - replace_function, recursive) - ''' - example of replace function: - def replace_conv(origin_conv: nn.Conv2D): - new_conv = nn.Conv2D( - in_channels=origin_conv._in_channels, - out_channels=origin_conv._out_channels, - kernel_size=origin_conv._kernel_size, - stride=2 - ) - return new_conv +def set_identity(parent_layer: nn.Layer, + layer_name: str, + layer_index: str=None) -> bool: + """set the layer specified by layer_name and layer_index to Indentity. - ''' + Args: + parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index. + layer_name (str): The name of target layer to be set to Indentity. + layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None. + + Returns: + bool: True if successfully, False otherwise. + """ + + stop_after = False + for sub_layer_name in parent_layer._sub_layers: + if stop_after: + parent_layer._sub_layers[sub_layer_name] = Identity() + continue + if sub_layer_name == layer_name: + stop_after = True + + if layer_index and stop_after: + stop_after = False + for sub_layer_index in parent_layer._sub_layers[ + layer_name]._sub_layers: + if stop_after: + parent_layer._sub_layers[layer_name][ + sub_layer_index] = Identity() + continue + if layer_index == sub_layer_index: + stop_after = True + + return stop_after -class WrapLayer(TheseusLayer): - def __init__(self, sub_layer): - super(WrapLayer, self).__init__() - self.sub_layer = sub_layer +def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[ + None, List[Dict[str, Union[nn.Layer, str, None]]]]: + """parse the string type pattern. - def forward(self, *inputs, **kwargs): - return self.sub_layer(*inputs, **kwargs) + Args: + pattern (str): The pattern to discribe layer. + parent_layer (nn.Layer): The root layer relative to the pattern. + Returns: + Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order: + [ + {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist}, + {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist}, + ... + ] + """ -def wrap_theseus(sub_layer): - wrapped_layer = WrapLayer(sub_layer) - return wrapped_layer + pattern_list = pattern.split(".") + if not pattern_list: + msg = f"The pattern('{pattern}') is illegal. Please check and retry." + logger.warning(msg) + return None + + layer_list = [] + while len(pattern_list) > 0: + if '[' in pattern_list[0]: + target_layer_name = pattern_list[0].split('[')[0] + target_layer_index = pattern_list[0].split('[')[1].split(']')[0] + else: + target_layer_name = pattern_list[0] + target_layer_index = None + + target_layer = getattr(parent_layer, target_layer_name, None) + + if target_layer is None: + msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')." + logger.warning(msg) + return None + + if target_layer_index and target_layer: + if int(target_layer_index) < 0 or int(target_layer_index) >= len( + target_layer): + msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0." + logger.warning(msg) + return None + + target_layer = target_layer[target_layer_index] + + layer_list.append({ + "layer": target_layer, + "name": target_layer_name, + "index": target_layer_index + }) + + pattern_list = pattern_list[1:] + parent_layer = target_layer + return layer_list diff --git a/ppcls/arch/backbone/legendary_models/esnet.py b/ppcls/arch/backbone/legendary_models/esnet.py index cf9c9626e..3a8d66903 100644 --- a/ppcls/arch/backbone/legendary_models/esnet.py +++ b/ppcls/arch/backbone/legendary_models/esnet.py @@ -217,7 +217,8 @@ class ESNet(TheseusLayer): class_num=1000, scale=1.0, dropout_prob=0.2, - class_expand=1280): + class_expand=1280, + return_patterns=None): super().__init__() self.scale = scale self.class_num = class_num @@ -268,6 +269,9 @@ class ESNet(TheseusLayer): self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) self.fc = Linear(self.class_expand, self.class_num) + if return_patterns is not None: + self.update_res(return_patterns) + def forward(self, x): x = self.conv1(x) x = self.max_pool(x) diff --git a/ppcls/arch/backbone/legendary_models/hrnet.py b/ppcls/arch/backbone/legendary_models/hrnet.py index 7c4898a13..da6c5f676 100644 --- a/ppcls/arch/backbone/legendary_models/hrnet.py +++ b/ppcls/arch/backbone/legendary_models/hrnet.py @@ -244,7 +244,7 @@ class HighResolutionModule(TheseusLayer): for i in range(len(num_filters)): self.basic_block_list.append( - nn.Sequential(*[ + nn.Sequential(* [ BasicBlock( num_channels=num_filters[i], num_filters=num_filters[i], @@ -367,7 +367,11 @@ class HRNet(TheseusLayer): model: nn.Layer. Specific HRNet model depends on args. """ - def __init__(self, width=18, has_se=False, class_num=1000, return_patterns=None): + def __init__(self, + width=18, + has_se=False, + class_num=1000, + return_patterns=None): super().__init__() self.width = width @@ -394,7 +398,7 @@ class HRNet(TheseusLayer): stride=2, act="relu") - self.layer1 = nn.Sequential(*[ + self.layer1 = nn.Sequential(* [ BottleneckBlock( num_channels=64 if i == 0 else 256, num_filters=64, @@ -458,7 +462,6 @@ class HRNet(TheseusLayer): weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) if return_patterns is not None: self.update_res(return_patterns) - self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): x = self.conv_layer1_1(x) diff --git a/ppcls/arch/backbone/legendary_models/inception_v3.py b/ppcls/arch/backbone/legendary_models/inception_v3.py index 50fbcb4cb..c5ccc3dc9 100644 --- a/ppcls/arch/backbone/legendary_models/inception_v3.py +++ b/ppcls/arch/backbone/legendary_models/inception_v3.py @@ -498,7 +498,6 @@ class Inception_V3(TheseusLayer): bias_attr=ParamAttr()) if return_patterns is not None: self.update_res(return_patterns) - self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): x = self.inception_stem(x) diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py index 944bdb146..8bda78d5c 100644 --- a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py +++ b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py @@ -128,7 +128,7 @@ class MobileNet(TheseusLayer): [int(512 * scale), 512, 1024, 512, 2], [int(1024 * scale), 1024, 1024, 1024, 1]] - self.blocks = nn.Sequential(*[ + self.blocks = nn.Sequential(* [ DepthwiseSeparable( num_channels=params[0], num_filters1=params[1], @@ -147,7 +147,6 @@ class MobileNet(TheseusLayer): weight_attr=ParamAttr(initializer=KaimingNormal())) if return_patterns is not None: self.update_res(return_patterns) - self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): x = self.conv(x) diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v3.py b/ppcls/arch/backbone/legendary_models/mobilenet_v3.py index 438e48a4f..1ad42d5c5 100644 --- a/ppcls/arch/backbone/legendary_models/mobilenet_v3.py +++ b/ppcls/arch/backbone/legendary_models/mobilenet_v3.py @@ -202,7 +202,6 @@ class MobileNetV3(TheseusLayer): self.fc = Linear(self.class_expand, class_num) if return_patterns is not None: self.update_res(return_patterns) - self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): x = self.conv(x) diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet.py b/ppcls/arch/backbone/legendary_models/pp_lcnet.py index 05bbccd35..327980f37 100644 --- a/ppcls/arch/backbone/legendary_models/pp_lcnet.py +++ b/ppcls/arch/backbone/legendary_models/pp_lcnet.py @@ -171,7 +171,8 @@ class PPLCNet(TheseusLayer): scale=1.0, class_num=1000, dropout_prob=0.2, - class_expand=1280): + class_expand=1280, + return_patterns=None): super().__init__() self.scale = scale self.class_expand = class_expand @@ -182,7 +183,7 @@ class PPLCNet(TheseusLayer): num_filters=make_divisible(16 * scale), stride=2) - self.blocks2 = nn.Sequential(*[ + self.blocks2 = nn.Sequential(* [ DepthwiseSeparable( num_channels=make_divisible(in_c * scale), num_filters=make_divisible(out_c * scale), @@ -192,7 +193,7 @@ class PPLCNet(TheseusLayer): for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"]) ]) - self.blocks3 = nn.Sequential(*[ + self.blocks3 = nn.Sequential(* [ DepthwiseSeparable( num_channels=make_divisible(in_c * scale), num_filters=make_divisible(out_c * scale), @@ -202,7 +203,7 @@ class PPLCNet(TheseusLayer): for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"]) ]) - self.blocks4 = nn.Sequential(*[ + self.blocks4 = nn.Sequential(* [ DepthwiseSeparable( num_channels=make_divisible(in_c * scale), num_filters=make_divisible(out_c * scale), @@ -212,7 +213,7 @@ class PPLCNet(TheseusLayer): for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"]) ]) - self.blocks5 = nn.Sequential(*[ + self.blocks5 = nn.Sequential(* [ DepthwiseSeparable( num_channels=make_divisible(in_c * scale), num_filters=make_divisible(out_c * scale), @@ -222,7 +223,7 @@ class PPLCNet(TheseusLayer): for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"]) ]) - self.blocks6 = nn.Sequential(*[ + self.blocks6 = nn.Sequential(* [ DepthwiseSeparable( num_channels=make_divisible(in_c * scale), num_filters=make_divisible(out_c * scale), @@ -248,6 +249,9 @@ class PPLCNet(TheseusLayer): self.fc = Linear(self.class_expand, class_num) + if return_patterns is not None: + self.update_res(return_patterns) + def forward(self, x): x = self.conv1(x) diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index 4f79c0d75..f37cfef9f 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -340,7 +340,6 @@ class ResNet(TheseusLayer): self.data_format = data_format if return_patterns is not None: self.update_res(return_patterns) - self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): with paddle.static.amp.fp16_guard(): diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py index 9b1750d54..9316e12d3 100644 --- a/ppcls/arch/backbone/legendary_models/vgg.py +++ b/ppcls/arch/backbone/legendary_models/vgg.py @@ -111,7 +111,11 @@ class VGGNet(TheseusLayer): model: nn.Layer. Specific VGG model depends on args. """ - def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None): + def __init__(self, + config, + stop_grad_layers=0, + class_num=1000, + return_patterns=None): super().__init__() self.stop_grad_layers = stop_grad_layers @@ -139,7 +143,6 @@ class VGGNet(TheseusLayer): self.fc3 = Linear(4096, class_num) if return_patterns is not None: self.update_res(return_patterns) - self.register_forward_post_hook(self._return_dict_hook) def forward(self, inputs): x = self.conv_block_1(inputs) diff --git a/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py b/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py index 5976ab1e8..dc9747af0 100644 --- a/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py +++ b/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py @@ -19,11 +19,11 @@ class TanhSuffix(paddle.nn.Layer): def PPLCNet_x2_5_Tanh(pretrained=False, use_ssld=False, **kwargs): - def replace_function(origin_layer): + def replace_function(origin_layer, pattern): new_layer = TanhSuffix(origin_layer) return new_layer - match_re = "linear_0" + pattern = "fc" model = PPLCNet_x2_5(pretrained=pretrained, use_ssld=use_ssld, **kwargs) - model.replace_sub(match_re, replace_function, True) + model.upgrade_sublayer(pattern, replace_function) return model diff --git a/ppcls/arch/backbone/variant_models/resnet_variant.py b/ppcls/arch/backbone/variant_models/resnet_variant.py index 08042ad58..0219344b1 100644 --- a/ppcls/arch/backbone/variant_models/resnet_variant.py +++ b/ppcls/arch/backbone/variant_models/resnet_variant.py @@ -5,7 +5,7 @@ __all__ = ["ResNet50_last_stage_stride1"] def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): - def replace_function(conv): + def replace_function(conv, pattern): new_conv = Conv2D( in_channels=conv._in_channels, out_channels=conv._out_channels, @@ -16,8 +16,8 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): bias_attr=conv._bias_attr) return new_conv - match_re = "conv2d_4[4|6]" + pattern = ["blocks[13].conv1.conv", "blocks[13].short.conv"] model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs) - model.replace_sub(match_re, replace_function, True) + model.upgrade_sublayer(pattern, replace_function) _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld) return model diff --git a/ppcls/arch/backbone/variant_models/vgg_variant.py b/ppcls/arch/backbone/variant_models/vgg_variant.py index b73ad3592..c1f75ba90 100644 --- a/ppcls/arch/backbone/variant_models/vgg_variant.py +++ b/ppcls/arch/backbone/variant_models/vgg_variant.py @@ -1,28 +1,28 @@ import paddle from paddle.nn import Sigmoid from ppcls.arch.backbone.legendary_models.vgg import VGG19 - + __all__ = ["VGG19Sigmoid"] - - + + class SigmoidSuffix(paddle.nn.Layer): def __init__(self, origin_layer): - super(SigmoidSuffix, self).__init__() + super().__init__() self.origin_layer = origin_layer self.sigmoid = Sigmoid() - + def forward(self, input, res_dict=None, **kwargs): x = self.origin_layer(input) x = self.sigmoid(x) return x - - + + def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs): - def replace_function(origin_layer): + def replace_function(origin_layer, pattern): new_layer = SigmoidSuffix(origin_layer) return new_layer - - match_re = "linear_2" + + pattern = "fc2" model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs) - model.replace_sub(match_re, replace_function, True) + model.upgrade_sublayer(pattern, replace_function) return model diff --git a/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml b/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml index 69265ec35..3d1b99378 100644 --- a/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml +++ b/ppcls/configs/Cartoonface/ResNet50_icartoon.yaml @@ -22,7 +22,7 @@ Arch: name: "ResNet50" pretrained: True BackboneStopLayer: - name: "flatten_0" + name: "flatten" output_dim: 2048 Head: name: "FC" diff --git a/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml b/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml index 967673f2a..626dd7c2e 100644 --- a/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml +++ b/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml @@ -28,7 +28,7 @@ Arch: pretrained: True use_ssld: True BackboneStopLayer: - name: flatten_0 + name: "flatten" Neck: name: FC embedding_size: 1280 diff --git a/ppcls/configs/Logo/ResNet50_ReID.yaml b/ppcls/configs/Logo/ResNet50_ReID.yaml index 48a80decd..0949add86 100644 --- a/ppcls/configs/Logo/ResNet50_ReID.yaml +++ b/ppcls/configs/Logo/ResNet50_ReID.yaml @@ -24,7 +24,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/configs/Products/ResNet50_vd_Aliproduct.yaml b/ppcls/configs/Products/ResNet50_vd_Aliproduct.yaml index 121087041..70f805647 100644 --- a/ppcls/configs/Products/ResNet50_vd_Aliproduct.yaml +++ b/ppcls/configs/Products/ResNet50_vd_Aliproduct.yaml @@ -25,7 +25,7 @@ Arch: name: ResNet50_vd pretrained: True BackboneStopLayer: - name: flatten_0 + name: "flatten" Neck: name: FC embedding_size: 2048 diff --git a/ppcls/configs/Products/ResNet50_vd_Inshop.yaml b/ppcls/configs/Products/ResNet50_vd_Inshop.yaml index 2571ea483..18ddfa3a8 100644 --- a/ppcls/configs/Products/ResNet50_vd_Inshop.yaml +++ b/ppcls/configs/Products/ResNet50_vd_Inshop.yaml @@ -25,7 +25,7 @@ Arch: name: ResNet50_vd pretrained: False BackboneStopLayer: - name: flatten_0 + name: "flatten" Neck: name: FC embedding_size: 2048 diff --git a/ppcls/configs/Products/ResNet50_vd_SOP.yaml b/ppcls/configs/Products/ResNet50_vd_SOP.yaml index 6900181ce..7728a6678 100644 --- a/ppcls/configs/Products/ResNet50_vd_SOP.yaml +++ b/ppcls/configs/Products/ResNet50_vd_SOP.yaml @@ -22,7 +22,7 @@ Arch: name: ResNet50_vd pretrained: False BackboneStopLayer: - name: flatten_0 + name: "flatten" Neck: name: FC embedding_size: 2048 diff --git a/ppcls/configs/Vehicle/PPLCNet_2.5x_ReID.yaml b/ppcls/configs/Vehicle/PPLCNet_2.5x_ReID.yaml index 2387bff83..eb9f145a1 100644 --- a/ppcls/configs/Vehicle/PPLCNet_2.5x_ReID.yaml +++ b/ppcls/configs/Vehicle/PPLCNet_2.5x_ReID.yaml @@ -27,7 +27,7 @@ Arch: pretrained: True use_ssld: True BackboneStopLayer: - name: "flatten_0" + name: "flatten" Neck: name: "FC" embedding_size: 1280 diff --git a/ppcls/configs/Vehicle/ResNet50.yaml b/ppcls/configs/Vehicle/ResNet50.yaml index ba9008943..6b6172475 100644 --- a/ppcls/configs/Vehicle/ResNet50.yaml +++ b/ppcls/configs/Vehicle/ResNet50.yaml @@ -23,7 +23,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/configs/Vehicle/ResNet50_ReID.yaml b/ppcls/configs/Vehicle/ResNet50_ReID.yaml index 6aebcbf0d..c13d59afd 100644 --- a/ppcls/configs/Vehicle/ResNet50_ReID.yaml +++ b/ppcls/configs/Vehicle/ResNet50_ReID.yaml @@ -24,7 +24,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/configs/quick_start/MobileNetV1_retrieval.yaml b/ppcls/configs/quick_start/MobileNetV1_retrieval.yaml index 99f9a1250..f088e1cd9 100644 --- a/ppcls/configs/quick_start/MobileNetV1_retrieval.yaml +++ b/ppcls/configs/quick_start/MobileNetV1_retrieval.yaml @@ -25,7 +25,7 @@ Arch: name: MobileNetV1 pretrained: False BackboneStopLayer: - name: flatten_0 + name: "flatten" Neck: name: FC embedding_size: 1024 diff --git a/ppcls/configs/slim/GeneralRecognition_PPLCNet_x2_5_quantization.yaml b/ppcls/configs/slim/GeneralRecognition_PPLCNet_x2_5_quantization.yaml index 3d2b0eb5e..7b21d0ba8 100644 --- a/ppcls/configs/slim/GeneralRecognition_PPLCNet_x2_5_quantization.yaml +++ b/ppcls/configs/slim/GeneralRecognition_PPLCNet_x2_5_quantization.yaml @@ -34,7 +34,7 @@ Arch: pretrained: False use_ssld: True BackboneStopLayer: - name: flatten_0 + name: "flatten" Neck: name: FC embedding_size: 1280 diff --git a/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml b/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml index 4d4f08da5..1f6fea887 100644 --- a/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml +++ b/ppcls/configs/slim/ResNet50_vehicle_cls_prune.yaml @@ -28,7 +28,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml b/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml index 0e45a5a9b..026b86547 100644 --- a/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml +++ b/ppcls/configs/slim/ResNet50_vehicle_cls_quantization.yaml @@ -27,7 +27,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml b/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml index 736c9847a..63b87f1ca 100644 --- a/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml +++ b/ppcls/configs/slim/ResNet50_vehicle_reid_prune.yaml @@ -31,7 +31,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml b/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml index 72dc31865..cca9915e2 100644 --- a/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml +++ b/ppcls/configs/slim/ResNet50_vehicle_reid_quantization.yaml @@ -30,7 +30,7 @@ Arch: name: "ResNet50_last_stage_stride1" pretrained: True BackboneStopLayer: - name: "adaptive_avg_pool2d_0" + name: "avg_pool" Neck: name: "VehicleNeck" in_channels: 2048 diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index bb0093ad8..21897e3ad 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -161,7 +161,7 @@ class Engine(object): if metric_config is not None: metric_config = metric_config.get("Train") if metric_config is not None: - if self.train_dataloader.collate_fn: + if hasattr(self.train_dataloader, "collate_fn"): for m_idx, m in enumerate(metric_config): if "TopkAcc" in m: msg = f"'TopkAcc' metric can not be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed." @@ -312,14 +312,14 @@ class Engine(object): self.output_dir, model_name=self.config["Arch"]["name"], prefix="epoch_{}".format(epoch_id)) - # save the latest model - save_load.save_model( - self.model, - self.optimizer, {"metric": acc, - "epoch": epoch_id}, - self.output_dir, - model_name=self.config["Arch"]["name"], - prefix="latest") + # save the latest model + save_load.save_model( + self.model, + self.optimizer, {"metric": acc, + "epoch": epoch_id}, + self.output_dir, + model_name=self.config["Arch"]["name"], + prefix="latest") if self.vdl_writer is not None: self.vdl_writer.close() diff --git a/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_custom_sampler.txt b/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_custom_sampler.txt new file mode 100644 index 000000000..c1dbc8961 --- /dev/null +++ b/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_custom_sampler.txt @@ -0,0 +1,27 @@ +===========================train_params=========================== +model_name:GeneralRecognition_PPLCNet_x2_5 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml -o DataLoader.Train.sampler.name="PKSampler" -o DataLoader.Train.sampler.sample_per_id=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## + diff --git a/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_multicard_eval.txt b/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_multicard_eval.txt new file mode 100644 index 000000000..165cfa9fb --- /dev/null +++ b/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_multicard_eval.txt @@ -0,0 +1,26 @@ +===========================train_params=========================== +model_name:GeneralRecognition_PPLCNet_x2_5 +python:python3.7 +gpu_list:0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## diff --git a/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_no_eval.txt b/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_no_eval.txt new file mode 100644 index 000000000..1e1675184 --- /dev/null +++ b/test_tipc/config/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_train_no_eval.txt @@ -0,0 +1,26 @@ +===========================train_params=========================== +model_name:GeneralRecognition_PPLCNet_x2_5 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +##