From 15a0afc67ca3e9c2ce846ab6b894552e7d11e608 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 22 Apr 2022 15:52:16 +0800 Subject: [PATCH] update code --- .../arch/backbone/legendary_models/resnet.py | 1 - .../strong_baseline_baseline.yaml | 11 +- .../{ => Pedestrian}/strong_baseline_m1.yaml | 23 ++- .../strong_baseline_m1_centerloss.yaml | 62 +++--- .../ResNet50_strong_baseline_market1501.yaml | 177 ------------------ ppcls/engine/evaluation/retrieval.py | 4 +- 6 files changed, 54 insertions(+), 224 deletions(-) rename ppcls/configs/{ => Pedestrian}/strong_baseline_baseline.yaml (96%) rename ppcls/configs/{ => Pedestrian}/strong_baseline_m1.yaml (92%) rename ppcls/configs/{ => Pedestrian}/strong_baseline_m1_centerloss.yaml (80%) delete mode 100644 ppcls/configs/PersonReID/ResNet50_strong_baseline_market1501.yaml diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index e371617a1..894366ae1 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -281,7 +281,6 @@ class ResNet(TheseusLayer): lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], data_format="NCHW", input_image_channel=3, - stem_act="relu", return_patterns=None, return_stages=None): super().__init__() diff --git a/ppcls/configs/strong_baseline_baseline.yaml b/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml similarity index 96% rename from ppcls/configs/strong_baseline_baseline.yaml rename to ppcls/configs/Pedestrian/strong_baseline_baseline.yaml index 4b1dd2fc0..c7fb30702 100644 --- a/ppcls/configs/strong_baseline_baseline.yaml +++ b/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml @@ -6,14 +6,14 @@ Global: # pretrained_model: "./pd_model_trace/ISE/ISE_MS_model" # pretrained ISE model for MSMT17 output_dir: "./output/" device: "gpu" - save_interval: 10 + save_interval: 40 eval_during_train: True eval_interval: 10 epochs: 120 - print_batch_step: 10 + print_batch_step: 20 use_visualdl: False # used for static mode and model export - image_shape: [3, 128, 256] + image_shape: [3, 256, 128] save_inference_dir: "./inference" eval_mode: "retrieval" @@ -22,14 +22,14 @@ Arch: name: "RecModel" infer_output_key: "features" infer_add_softmax: False - Backbone: + Backbone: name: "ResNet50_last_stage_stride1" pretrained: True stem_act: null BackboneStopLayer: name: "flatten" Head: - name: "FC" + name: "FC" embedding_size: 2048 class_num: 751 # loss function config for traing/eval process @@ -138,4 +138,3 @@ Metric: - Recallk: topk: [1, 5] - mAP: {} - diff --git a/ppcls/configs/strong_baseline_m1.yaml b/ppcls/configs/Pedestrian/strong_baseline_m1.yaml similarity index 92% rename from ppcls/configs/strong_baseline_m1.yaml rename to ppcls/configs/Pedestrian/strong_baseline_m1.yaml index 146bed725..1343028ae 100644 --- a/ppcls/configs/strong_baseline_m1.yaml +++ b/ppcls/configs/Pedestrian/strong_baseline_m1.yaml @@ -4,17 +4,19 @@ Global: pretrained_model: null output_dir: "./output/" device: "gpu" - save_interval: 10 + save_interval: 40 eval_during_train: True eval_interval: 10 epochs: 120 print_batch_step: 20 use_visualdl: False + warmup_by_epoch: True + eval_mode: "retrieval" + re_ranking: False + feat_from: "neck" # 'backbone' or 'neck' # used for static mode and model export image_shape: [3, 256, 128] save_inference_dir: "./inference" - eval_mode: "retrieval" - feat_from: "neck" # 'backbone' or 'neck' # model architecture Arch: @@ -29,13 +31,17 @@ Arch: name: "flatten" Neck: name: BNNeck - num_filters: 2048 - # trainable: False # TODO: free bn.bias + num_features: &feat_dim 2048 + # trainable: False # TODO: freeze bn.bias Head: name: "FC" - embedding_size: 2048 - class_num: 751 - bias_attr: false + embedding_size: *feat_dim + class_num: &class_num 751 + weight_attr: + initializer: + name: Normal + std: 0.001 + bias_attr: False # loss function config for traing/eval process Loss: @@ -160,4 +166,3 @@ Metric: - Recallk: topk: [1, 5] - mAP: {} - diff --git a/ppcls/configs/strong_baseline_m1_centerloss.yaml b/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml similarity index 80% rename from ppcls/configs/strong_baseline_m1_centerloss.yaml rename to ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml index b8435c654..057117cc7 100644 --- a/ppcls/configs/strong_baseline_m1_centerloss.yaml +++ b/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml @@ -4,17 +4,19 @@ Global: pretrained_model: null output_dir: "./output/" device: "gpu" - save_interval: 10 + save_interval: 40 eval_during_train: True eval_interval: 10 epochs: 120 print_batch_step: 20 use_visualdl: False + warmup_by_epoch: True + eval_mode: "retrieval" + re_ranking: False + feat_from: "neck" # 'backbone' or 'neck' # used for static mode and model export image_shape: [3, 256, 128] save_inference_dir: "./inference" - eval_mode: "retrieval" - feat_from: "neck" # 'backbone' or 'neck' # model architecture Arch: @@ -29,13 +31,17 @@ Arch: name: "flatten" Neck: name: BNNeck - num_filters: 2048 - # trainable: False # TODO: free bn.bias + num_features: &feat_dim 2048 + # trainable: False # TODO: freeze bn.bias Head: name: "FC" - embedding_size: 2048 - class_num: 751 - bias_attr: false + embedding_size: *feat_dim + class_num: &class_num 751 + weight_attr: + initializer: + name: Normal + std: 0.001 + bias_attr: False # loss function config for traing/eval process Loss: @@ -43,35 +49,36 @@ Loss: - CELoss: weight: 1.0 epsilon: 0.1 - - TripletLossV2: + - TripletLossV3: weight: 1.0 margin: 0.3 normalize_feature: false - CenterLoss: weight: 0.0005 - num_classes: 751 - feat_dim: 2048 + num_classes: *class_num + feat_dim: *feat_dim Eval: - CELoss: weight: 1.0 Optimizer: - model: - name: Adam - lr: - name: Piecewise - decay_epochs: [30, 60] - values: [0.00035, 0.000035, 0.0000035] - warmup_epoch: 10 - warmup_start_lr: 0.0000035 - regularizer: - name: 'L2' - coeff: 0.0005 - loss: - name: SGD - lr: - name: Constant - learning_rate: 0.5 + - Adam: + scope: model + lr: + name: Piecewise + decay_epochs: [30, 60] + values: [0.00035, 0.000035, 0.0000035] + warmup_epoch: 10 + warmup_start_lr: 0.0000035 + warmup_by_epoch: True + regularizer: + name: 'L2' + coeff: 0.0005 + - SGD: + scope: CenterLoss + lr: + name: Constant + learning_rate: 1000.0 # set to ori_lr*(1/centerloss_weight) to void manually scaling centers' gradidents. # data loader for train and eval DataLoader: @@ -170,4 +177,3 @@ Metric: - Recallk: topk: [1, 5] - mAP: {} - diff --git a/ppcls/configs/PersonReID/ResNet50_strong_baseline_market1501.yaml b/ppcls/configs/PersonReID/ResNet50_strong_baseline_market1501.yaml deleted file mode 100644 index f63f58a8a..000000000 --- a/ppcls/configs/PersonReID/ResNet50_strong_baseline_market1501.yaml +++ /dev/null @@ -1,177 +0,0 @@ -# global configs -Global: - checkpoints: null - pretrained_model: null - output_dir: "./output/" - device: "gpu" - save_interval: 40 - eval_during_train: True - eval_interval: 10 - epochs: 120 - print_batch_step: 20 - use_visualdl: False - warmup_by_epoch: True - eval_mode: "retrieval" - re_ranking: True - # used for static mode and model export - image_shape: [3, 256, 128] - save_inference_dir: "./inference" - -# model architecture -Arch: - name: "RecModel" - infer_output_key: "features" - infer_add_softmax: False - Backbone: - name: "ResNet50_last_stage_stride1" - pretrained: True - stem_act: null - BackboneStopLayer: - name: "flatten" - Neck: - name: BNNeck - num_features: &feat_dim 2048 - Head: - name: "FC" - embedding_size: *feat_dim - class_num: &class_num 751 - weight_attr: - initializer: - name: Normal - std: 0.001 - bias_attr: False - -# loss function config for traing/eval process -Loss: - Train: - - CELoss: - weight: 1.0 - epsilon: 0.1 - - TripletLossV3: - weight: 1.0 - margin: 0.3 - normalize_feature: false - - CenterLoss: - weight: 0.0005 - num_classes: *class_num - feat_dim: *feat_dim - Eval: - - CELoss: - weight: 1.0 - -Optimizer: - - Adam: - scope: model - lr: - name: Piecewise - decay_epochs: [30, 60] - values: [0.00035, 0.000035, 0.0000035] - warmup_epoch: 10 - warmup_start_lr: 0.0000035 - warmup_by_epoch: True - regularizer: - name: 'L2' - coeff: 0.0005 - - SGD: - scope: CenterLoss - lr: - name: Constant - learning_rate: 1000.0 - -# data loader for train and eval -DataLoader: - Train: - dataset: - name: "Market1501" - image_root: "./dataset/Market-1501-v15.09.15" - cls_label_path: "bounding_box_train" - transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - - ResizeImage: - size: [128, 256] - - RandFlipImage: - flip_code: 1 - - Pad: - padding: 10 - - RandCropImage: - size: [128, 256] - scale: [ 0.8022, 0.8022 ] - ratio: [ 0.5, 0.5 ] - - NormalizeImage: - scale: 0.00392157 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - - RandomErasing: - EPSILON: 0.5 - sl: 0.02 - sh: 0.4 - r1: 0.3 - mean: [0.4914, 0.4822, 0.4465] - sampler: - name: DistributedRandomIdentitySampler - batch_size: 64 - num_instances: 4 - drop_last: True - shuffle: True - loader: - num_workers: 4 - use_shared_memory: True - Eval: - Query: - dataset: - name: "Market1501" - image_root: "./dataset/Market-1501-v15.09.15" - cls_label_path: "query" - transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - - ResizeImage: - size: [128, 256] - - NormalizeImage: - scale: 0.00392157 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - sampler: - name: DistributedBatchSampler - batch_size: 128 - drop_last: False - shuffle: False - loader: - num_workers: 4 - use_shared_memory: True - - Gallery: - dataset: - name: "Market1501" - image_root: "./dataset/Market-1501-v15.09.15" - cls_label_path: "bounding_box_test" - transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - - ResizeImage: - size: [128, 256] - - NormalizeImage: - scale: 0.00392157 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - sampler: - name: DistributedBatchSampler - batch_size: 128 - drop_last: False - shuffle: False - loader: - num_workers: 4 - use_shared_memory: True - -Metric: - Eval: - - Recallk: - topk: [1, 5] - - mAP: {} diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 52f67dfed..408c5d201 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -82,6 +82,7 @@ def retrieval_eval(engine, epoch_id=0): metric_dict[key] += metric_tmp[key] * block_fea.shape[ 0] / len(query_feas) else: + metric_dict = dict() distmat = re_ranking( query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3) cmc, mAP = eval_func(distmat, @@ -93,9 +94,6 @@ def retrieval_eval(engine, epoch_id=0): metric_dict["recall5(RK)"] = cmc[4] metric_dict["mAP(RK)"] = mAP - for key in metric_tmp: - metric_dict[key] = metric_tmp[key] * block_fea.shape[0] / len( - query_feas) metric_info_list = [] for key in metric_dict: if metric_key is None: