def __init__(self, embedding_size, class_num, **kwargs):
    """Build a single fully connected classification layer.

    Args:
        embedding_size: input feature size passed to ``nn.Linear``.
        class_num: output size passed to ``nn.Linear``.
        **kwargs: optional ``weight_attr`` / ``bias_attr`` entries. Each is
            handed to ``get_param_attr_dict`` (so it may be ``None``, a bool,
            or a config dict) and the result is forwarded to ``nn.Linear``.
    """
    super(FC, self).__init__()
    self.embedding_size = embedding_size
    self.class_num = class_num

    # ``get_param_attr_dict`` accepts exactly one argument; the previous
    # call passed an extra ``None`` and raised TypeError whenever a config
    # supplied ``weight_attr`` or ``bias_attr``.
    if 'weight_attr' in kwargs:
        weight_attr = get_param_attr_dict(kwargs['weight_attr'])
    else:
        # Default kept from the original layer: Xavier-normal weights.
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.XavierNormal())

    bias_attr = None
    if 'bias_attr' in kwargs:
        bias_attr = get_param_attr_dict(kwargs['bias_attr'])

    self.fc = nn.Linear(
        self.embedding_size,
        self.class_num,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
def get_blacklist_model_in_static_mode():
    """Collect the public names of the two vision-transformer backbone modules.

    Returns:
        list: ``distilled_vision_transformer.__all__`` followed by
            ``vision_transformer.__all__``.
    """
    # Imported inside the function, exactly as the original does.
    from ppcls.arch.backbone import (distilled_vision_transformer,
                                     vision_transformer)
    blacklist = distilled_vision_transformer.__all__ + vision_transformer.__all__
    return blacklist


def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]]
                        ) -> "Union[None, bool, paddle.ParamAttr]":
    """Parse a ``paddle.ParamAttr`` from a config dict.

    Args:
        ParamAttr_config: ``None`` or a bool (both returned unchanged), or a
            dict that may contain the keys ``initializer``, ``learning_rate``
            and ``regularizer``. ``initializer`` / ``regularizer`` are dicts
            with a mandatory ``name`` (looked up in ``paddle.nn.initializer``
            / ``paddle.regularizer``); their remaining keys become
            constructor keyword arguments.

    Returns:
        ``None``, the bool that was passed in, or the generated
        ``paddle.ParamAttr``.

    Raises:
        ValueError: if ``name`` is missing from the initializer/regularizer
            config, or ``learning_rate`` is not an int or float.
    """
    if ParamAttr_config is None:
        return None
    if isinstance(ParamAttr_config, bool):
        return ParamAttr_config

    ParamAttr_dict = {}

    # The key was previously misspelled 'initiliazer', so a config written
    # with the correct spelling was silently ignored, and the misspelled key
    # is not a valid ``paddle.ParamAttr`` keyword either.
    if 'initializer' in ParamAttr_config:
        # Work on a shallow copy: popping 'name' must not alter the caller's
        # config, otherwise a second parse of the same config fails.
        initializer_cfg = dict(ParamAttr_config['initializer'])
        if 'name' not in initializer_cfg:
            raise ValueError("'name' must be specified in initializer_cfg")
        initializer_name = initializer_cfg.pop('name')
        ParamAttr_dict['initializer'] = getattr(
            paddle.nn.initializer, initializer_name)(**initializer_cfg)

    if 'learning_rate' in ParamAttr_config:
        # NOTE: only a single scalar value is supported for now
        learning_rate_value = ParamAttr_config['learning_rate']
        if not isinstance(learning_rate_value, (int, float)):
            raise ValueError(
                f"learning_rate_value must be float or int, but got {type(learning_rate_value)}"
            )
        ParamAttr_dict['learning_rate'] = learning_rate_value

    if 'regularizer' in ParamAttr_config:
        regularizer_cfg = dict(ParamAttr_config['regularizer'])
        if 'name' not in regularizer_cfg:
            raise ValueError("'name' must be specified in regularizer_cfg")
        # L1Decay or L2Decay
        regularizer_name = regularizer_cfg.pop('name')
        ParamAttr_dict['regularizer'] = getattr(
            paddle.regularizer, regularizer_name)(**regularizer_cfg)

    return paddle.ParamAttr(**ParamAttr_dict)
TripletLossV3 + lr: + name: Constant + learning_rate: 0.5 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "VeriWild" + image_root: "./dataset/market1501/bounding_box_train" + cls_label_path: "./dataset/market1501/bounding_box_train.txt" + relabel: True + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [128, 256] + - RandFlipImage: + flip_code: 1 + - Pad: + padding: 10 + - RandCropImage: + size: [128, 256] + scale: [ 0.8022, 0.8022 ] + ratio: [ 0.5, 0.5 ] + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0.4914, 0.4822, 0.4465] + sampler: + name: DistributedRandomIdentitySampler + batch_size: 64 + num_instances: 4 + drop_last: True + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + Eval: + Query: + dataset: + name: "VeriWild" + image_root: "./dataset/market1501/query" + cls_label_path: "./dataset/market1501/query.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [128, 256] + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: "VeriWild" + image_root: "./dataset/market1501/bounding_box_test" + cls_label_path: "./dataset/market1501/bounding_box_test.txt" + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [128, 256] + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: 
[1, 5] + - mAP: {} diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 71a7e182a..b20d1f572 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -298,12 +298,24 @@ class Engine(object): self.max_iter = len(self.train_dataloader) - 1 if platform.system( ) == "Windows" else len(self.train_dataloader) + + if self.config["Global"].get("warmup_epoch_by_epoch", False): + for i in range(len(self.lr_sch)): + self.lr_sch[i].step() + logger.info( + "lr_sch step once before first epoch, when Global.warmup_epoch_by_epoch=True" + ) + for epoch_id in range(best_metric["epoch"] + 1, self.config["Global"]["epochs"] + 1): acc = 0.0 # for one epoch train self.train_epoch_func(self, epoch_id, print_batch_step) + if self.config["Global"].get("warmup_epoch_by_epoch", False): + for i in range(len(self.lr_sch)): + self.lr_sch[i].step() + if self.use_dali: self.train_dataloader.reset() metric_msg = ", ".join([ diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 8471a42c7..9229fc071 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -16,6 +16,8 @@ from __future__ import division from __future__ import print_function import platform + +import numpy as np import paddle from ppcls.utils import logger @@ -49,34 +51,55 @@ def retrieval_eval(engine, epoch_id=0): metric_dict = {metric_key: 0.} else: metric_dict = dict() - for block_idx, block_fea in enumerate(fea_blocks): - similarity_matrix = paddle.matmul( - block_fea, gallery_feas, transpose_y=True) - if query_query_id is not None: - query_id_block = query_id_blocks[block_idx] - query_id_mask = (query_id_block != gallery_unique_id.t()) + reranking_flag = engine.config['Global'].get('re_ranking', False) + logger.info(f"re_ranking={reranking_flag}") + if not reranking_flag: + for block_idx, block_fea in enumerate(fea_blocks): + similarity_matrix = paddle.matmul( + block_fea, gallery_feas, transpose_y=True) + if query_query_id is 
def re_ranking(queFea,
               galFea,
               k1=20,
               k2=6,
               lambda_value=0.5,
               local_distmat=None,
               only_local=False):
    """Re-rank query/gallery distances with k-reciprocal neighbours.

    The pairwise distance over the concatenated query + gallery features is
    blended with a Jaccard distance built from k-reciprocal neighbour sets.

    Args:
        queFea: 2-D query feature tensor; rows are samples. If the features
            are numpy arrays, convert them to paddle tensors first.
        galFea: 2-D gallery feature tensor; rows are samples.
        k1: neighbourhood size used for the k-reciprocal sets.
        k2: neighbourhood size used for the local query expansion
            (``k2 == 1`` disables it).
        lambda_value: weight of the original distance in the final blend;
            the Jaccard distance gets ``1 - lambda_value``.
        local_distmat: optional precomputed numpy distance matrix over
            query + gallery samples.
        only_local: if True, use ``local_distmat`` alone and skip the
            feature-based distance.

    Returns:
        numpy.ndarray: re-ranked distances, rows = queries, columns = gallery.

    NOTE(review): the ``retrieval_eval`` call visible in this diff passes four
    id tensors positionally after ``galFea`` and also ``k1=``/``k2=``/
    ``lambda_value=`` keywords; with this signature that call raises
    TypeError (multiple values for ``k1``). Confirm which side should change.
    """
    query_num = queFea.shape[0]
    all_num = query_num + galFea.shape[0]
    if only_local:
        original_dist = local_distmat
    else:
        feat = paddle.concat([queFea, galFea])
        logger.info('using GPU to compute original distance')

        # Squared L2 distance: |a|^2 + |b|^2 - 2 * a.b
        sq_norm = paddle.pow(feat, 2).sum(
            axis=1, keepdim=True).expand([all_num, all_num])
        distmat = sq_norm + sq_norm.t()
        distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0)

        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat

    # Despite its name this is the side length of the full (query + gallery)
    # distance matrix.
    gallery_num = original_dist.shape[0]
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    logger.info('starting re_ranking')
    # Loop-invariant: size of the halved neighbourhood used for expansion.
    half_k1 = int(np.around(k1 / 2)) + 1
    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for candidate in k_reciprocal_index:
            candidate_forward_k_neigh_index = initial_rank[candidate, :half_k1]
            candidate_backward_k_neigh_index = initial_rank[
                candidate_forward_k_neigh_index, :half_k1]
            fi_candidate = np.where(
                candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[
                fi_candidate]
            # Absorb the candidate's set when more than 2/3 of it overlaps.
            overlap = np.intersect1d(candidate_k_reciprocal_index,
                                     k_reciprocal_index)
            if len(overlap) > 2 / 3 * len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(
                    k_reciprocal_expansion_index,
                    candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)

    # The former ``time.time() - t`` bookkeeping was removed here and below:
    # ``t`` was never defined, so it raised NameError, and the values were
    # never used.
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: average V over each sample's k2 neighbours.
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank

    # invIndex[c] lists the rows of V that have a non-zero entry in column c.
    invIndex = [np.where(V[:, i] != 0)[0] for i in range(gallery_num)]

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
                V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value
                                 ) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    # Keep only the query rows against the gallery columns.
    final_dist = final_dist[:query_num, query_num:]
    return final_dist


def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric.

    Key: for each query identity, its gallery images from the same camera
    view are discarded.

    Args:
        distmat: 2-D numpy array of distances, rows = queries, columns =
            gallery samples (smaller means more similar).
        q_pids: 1-D numpy array of query identity ids.
        g_pids: 1-D numpy array of gallery identity ids.
        q_camids: 1-D numpy array of query camera ids.
        g_camids: 1-D numpy array of gallery camera ids.
        max_rank: number of ranks kept in the CMC curve; reduced to the
            gallery size when the gallery is smaller.

    Returns:
        tuple: ``(all_cmc, mAP)`` where ``all_cmc`` is a float32 array with
        the CMC curve averaged over valid queries and ``mAP`` is the mean
        average precision.

    Raises:
        AssertionError: if no query identity appears in the gallery.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(
            num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        hits = orig_cmc.cumsum()
        cmc = hits.copy()
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        # precision at every rank, kept only where a correct match occurs
        precision_at_rank = hits / np.arange(1., len(hits) + 1.)
        AP = (precision_at_rank * orig_cmc).sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
class CenterLoss(nn.Layer):
    """Center loss class

    Args:
        num_classes (int): number of classes.
        feat_dim (int): number of feature dimensions.
    """

    def __init__(self, num_classes: int, feat_dim: int):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Class centers are a learnable parameter, initialised from N(0, 1).
        random_init_centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim])
        self.centers = self.create_parameter(
            shape=(self.num_classes, self.feat_dim),
            default_initializer=nn.initializer.Assign(random_init_centers))
        self.add_parameter("centers", self.centers)

    def __call__(self, input: Dict[str, paddle.Tensor],
                 target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
        """compute center loss.

        Args:
            input (Dict[str, paddle.Tensor]): {'backbone': (batch_size, feature_dim), ...}.
            target (paddle.Tensor): ground truth label with shape (batch_size, )
                or (batch_size, 1).

        Returns:
            Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
        """
        feats = input['backbone']
        labels = target

        # squeeze labels to shape (batch_size, )
        if labels.ndim >= 2 and labels.shape[-1] == 1:
            labels = paddle.squeeze(labels, axis=[-1])

        batch_size = feats.shape[0]
        # Squared distance of every feature to every center:
        # |f|^2 + |c|^2 - 2 * f.c
        distmat = paddle.pow(feats, 2).sum(axis=1, keepdim=True).expand([batch_size, self.num_classes]) + \
            paddle.pow(self.centers, 2).sum(axis=1, keepdim=True).expand([self.num_classes, batch_size]).t()
        distmat = distmat.addmm(x=feats, y=self.centers.t(), beta=1, alpha=-2)

        # Keep only the distance to each sample's own class center.
        classes = paddle.arange(self.num_classes).astype(labels.dtype)
        labels = labels.unsqueeze(1).expand([batch_size, self.num_classes])
        mask = labels.equal(classes.expand([batch_size, self.num_classes]))

        dist = distmat * mask.astype(feats.dtype)
        loss = dist.clip(min=1e-12, max=1e+12).sum() / batch_size
        return {'CenterLoss': loss}


class TripletLossV3(nn.Layer):
    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    Args:
        margin (float, optional): margin for ``nn.MarginRankingLoss``. When
            None, ``nn.SoftMarginLoss`` is used instead.
        normalize_feature (bool, optional): L2-normalize features before
            computing distances. Defaults to False.
    """

    def __init__(self, margin=None, normalize_feature=False):
        super(TripletLossV3, self).__init__()
        self.normalize_feature = normalize_feature
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def forward(self, input, target):
        """Compute the batch-hard triplet loss.

        Args:
            input (dict): must contain 'backbone', a (batch_size, feature_dim) tensor.
            target (paddle.Tensor): labels of the batch
                (assumed shape (batch_size, ) -- TODO confirm against caller).

        Returns:
            dict: {"TripletLossV3": loss}.
        """
        global_feat = input["backbone"]
        if self.normalize_feature:
            global_feat = self._normalize(global_feat, axis=-1)
        dist_mat = self._euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = self._hard_example_mining(dist_mat, target)
        y = paddle.ones_like(dist_an)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            # Soft-margin variant: previously this branch was missing and
            # ``loss`` was unbound whenever ``margin`` was None.
            loss = self.ranking_loss(dist_an - dist_ap, y)

        return {"TripletLossV3": loss}

    def _normalize(self, x: paddle.Tensor, axis: int=-1) -> paddle.Tensor:
        """Normalizing to unit length along the specified dimension.

        Args:
            x (paddle.Tensor): (batch_size, feature_dim)
            axis (int, optional): normalization dim. Defaults to -1.

        Returns:
            paddle.Tensor: (batch_size, feature_dim)
        """
        # 1e-12 guards against division by zero for all-zero rows.
        x = 1. * x / (paddle.norm(
            x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
        return x

    def _euclidean_dist(self, x: paddle.Tensor,
                        y: paddle.Tensor) -> paddle.Tensor:
        """compute euclidean distance between two batched vectors

        Args:
            x (paddle.Tensor): (N, feature_dim)
            y (paddle.Tensor): (M, feature_dim)

        Returns:
            paddle.Tensor: (N, M)
        """
        m, n = x.shape[0], y.shape[0]
        xx = paddle.pow(x, 2).sum(1, keepdim=True).expand([m, n])
        yy = paddle.pow(y, 2).sum(1, keepdim=True).expand([n, m]).t()
        dist = xx + yy
        # dist = dist - 2 * (x @ y.t())
        dist = dist.addmm(x, y.t(), alpha=-2, beta=1)
        dist = dist.clip(min=1e-12).sqrt()  # for numerical stability
        return dist

    def _hard_example_mining(
            self,
            dist_mat: paddle.Tensor,
            labels: paddle.Tensor,
            return_inds: bool=False) -> Tuple[paddle.Tensor, ...]:
        """For each anchor, find the hardest positive and negative sample.

        Args:
            dist_mat (paddle.Tensor): pair wise distance between samples, [N, N]
            labels (paddle.Tensor): labels, [N, ]
            return_inds (bool, optional): whether to return the indices. Defaults to False.

        Returns:
            Tuple[paddle.Tensor, ...]: ``(dist_ap, dist_an)``, each of shape
            (N, ); when ``return_inds`` is True, additionally
            ``(p_inds, n_inds)`` with the batch indices of the selected
            hardest positive / negative samples.

        NOTE: Only consider the case in which all labels have same num of samples,
            thus we can cope with all anchors in parallel.
        """
        assert len(dist_mat.shape) == 2
        assert dist_mat.shape[0] == dist_mat.shape[1]
        N = dist_mat.shape[0]

        # shape [N, N]
        is_pos = labels.expand([N, N]).equal(labels.expand([N, N]).t())
        is_neg = labels.expand([N, N]).not_equal(labels.expand([N, N]).t())

        # Rows can be reshaped to [N, -1] only because every identity has the
        # same number of samples (see NOTE above).
        pos_dist = dist_mat[is_pos].reshape([N, -1])
        neg_dist = dist_mat[is_neg].reshape([N, -1])

        # `dist_ap` means distance(anchor, positive), shape [N, 1]
        dist_ap = paddle.max(pos_dist, 1, keepdim=True)
        # `dist_an` means distance(anchor, negative), shape [N, 1]
        dist_an = paddle.min(neg_dist, 1, keepdim=True)

        if return_inds:
            # Positions of the hardest samples inside the per-anchor
            # positive / negative rows, shape [N, 1]. The previous code
            # referenced undefined names and torch-only tensor methods here.
            relative_p_inds = paddle.argmax(pos_dist, axis=1, keepdim=True)
            relative_n_inds = paddle.argmin(neg_dist, axis=1, keepdim=True)
            # shape [N, N]; every row is 0..N-1 (column index in the batch)
            ind = paddle.arange(0, N).unsqueeze(0).expand([N, N])
            # shape [N, 1]: map relative positions back to batch indices
            p_inds = paddle.take_along_axis(
                ind[is_pos].reshape([N, -1]), relative_p_inds, axis=1)
            n_inds = paddle.take_along_axis(
                ind[is_neg].reshape([N, -1]), relative_n_inds, axis=1)
            # shape [N]
            return (dist_ap.squeeze(1), dist_an.squeeze(1),
                    p_inds.squeeze(1), n_inds.squeeze(1))

        # shape [N]
        return dist_ap.squeeze(1), dist_an.squeeze(1)