add center loss
parent
05770197c3
commit
41e1a86caf
@@ -19,16 +19,29 @@ from __future__ import print_function
 import paddle
 import paddle.nn as nn
 
+from ppcls.arch.utils import get_param_attr_dict
+
 
 class FC(nn.Layer):
-    def __init__(self, embedding_size, class_num):
+    def __init__(self, embedding_size, class_num, **kwargs):
         super(FC, self).__init__()
         self.embedding_size = embedding_size
         self.class_num = class_num
 
         weight_attr = paddle.ParamAttr(
             initializer=paddle.nn.initializer.XavierNormal())
-        self.fc = paddle.nn.Linear(
-            self.embedding_size, self.class_num, weight_attr=weight_attr)
+        if 'weight_attr' in kwargs:
+            weight_attr = get_param_attr_dict(kwargs['weight_attr'])
+
+        bias_attr = None
+        if 'bias_attr' in kwargs:
+            bias_attr = get_param_attr_dict(kwargs['bias_attr'])
+
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr)
 
     def forward(self, input, label=None):
         out = self.fc(input)
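A minimal usage sketch for the updated FC head follows; the keyword values mirror the Head section of the config further down, and the import path is an assumption rather than part of this diff.

    # Hypothetical usage of FC with the new weight_attr/bias_attr kwargs.
    from ppcls.arch.gears.fc import FC  # assumed import path

    head = FC(embedding_size=2048,
              class_num=751,
              weight_attr={'initializer': {'name': 'Normal', 'std': 0.001}},
              bias_attr=False)  # bias disabled, as in the BNNeck strong baseline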
@@ -14,9 +14,11 @@
 import six
 import types
+import paddle
 from difflib import SequenceMatcher
 
 from . import backbone
+from typing import Any, Dict, Union
 
 
 def get_architectures():
@@ -31,8 +33,8 @@ def get_architectures():
 
 def get_blacklist_model_in_static_mode():
-    from ppcls.arch.backbone import distilled_vision_transformer
-    from ppcls.arch.backbone import vision_transformer
+    from ppcls.arch.backbone import (distilled_vision_transformer,
+                                     vision_transformer)
     blacklist = distilled_vision_transformer.__all__ + vision_transformer.__all__
     return blacklist
 
 
@@ -51,3 +53,47 @@ def similar_architectures(name='', names=[], thresh=0.1, topk=10):
     scores.sort(key=lambda x: x[1], reverse=True)
     similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
     return similar_names
+
+
+def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]]
+                        ) -> Union[None, bool, paddle.ParamAttr]:
+    """Parse a ParamAttr from a config dict.
+
+    Args:
+        ParamAttr_config (Union[None, bool, Dict[str, Dict]]): ParamAttr config.
+
+    Returns:
+        Union[None, bool, paddle.ParamAttr]: generated ParamAttr.
+    """
+    if ParamAttr_config is None:
+        return None
+    if isinstance(ParamAttr_config, bool):
+        return ParamAttr_config
+    ParamAttr_dict = {}
+    if 'initializer' in ParamAttr_config:
+        initializer_cfg = ParamAttr_config.get('initializer')
+        if 'name' in initializer_cfg:
+            initializer_name = initializer_cfg.pop('name')
+            ParamAttr_dict['initializer'] = getattr(
+                paddle.nn.initializer, initializer_name)(**initializer_cfg)
+        else:
+            raise ValueError("'name' must be specified in initializer_cfg")
+    if 'learning_rate' in ParamAttr_config:
+        # NOTE: only a single scalar value is supported now
+        learning_rate_value = ParamAttr_config.get('learning_rate')
+        if isinstance(learning_rate_value, (int, float)):
+            ParamAttr_dict['learning_rate'] = learning_rate_value
+        else:
+            raise ValueError(
+                f"learning_rate_value must be float or int, but got {type(learning_rate_value)}"
+            )
+    if 'regularizer' in ParamAttr_config:
+        regularizer_cfg = ParamAttr_config.get('regularizer')
+        if 'name' in regularizer_cfg:
+            # L1Decay or L2Decay
+            regularizer_name = regularizer_cfg.pop('name')
+            ParamAttr_dict['regularizer'] = getattr(
+                paddle.regularizer, regularizer_name)(**regularizer_cfg)
+        else:
+            raise ValueError("'name' must be specified in regularizer_cfg")
+    return paddle.ParamAttr(**ParamAttr_dict)
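A short, illustrative example of get_param_attr_dict in isolation (the config fragment below is hypothetical, not taken from this commit):

    import paddle
    from ppcls.arch.utils import get_param_attr_dict

    cfg = {
        'initializer': {'name': 'Normal', 'std': 0.001},
        'learning_rate': 1.0,
        'regularizer': {'name': 'L2Decay', 'coeff': 0.0005},
    }
    attr = get_param_attr_dict(cfg)
    assert isinstance(attr, paddle.ParamAttr)
    # booleans and None pass straight through, e.g. bias_attr: False in a Head config
    assert get_param_attr_dict(False) is False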
@@ -0,0 +1,178 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: "./output/"
+  device: "gpu"
+  save_interval: 40
+  eval_during_train: True
+  eval_interval: 10
+  epochs: 120
+  print_batch_step: 20
+  use_visualdl: False
+  warmup_epoch_by_epoch: True
+  eval_mode: "retrieval"
+  re_ranking: True
+  # used for static mode and model export
+  image_shape: [3, 256, 128]
+  save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+  name: "RecModel"
+  infer_output_key: "features"
+  infer_add_softmax: False
+  Backbone:
+    name: "ResNet50_last_stage_stride1"
+    pretrained: True
+    stem_act: null
+  BackboneStopLayer:
+    name: "flatten"
+  Neck:
+    name: BNNeck
+    num_features: &feat_dim 2048
+  Head:
+    name: "FC"
+    embedding_size: *feat_dim
+    class_num: &class_num 751
+    weight_attr:
+      initializer:
+        name: Normal
+        std: 0.001
+    bias_attr: False
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+        epsilon: 0.1
+    - TripletLossV3:
+        weight: 1.0
+        margin: 0.3
+        normalize_feature: false
+    - CenterLoss:
+        weight: 0.0005
+        num_classes: *class_num
+        feat_dim: *feat_dim
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  - Adam:
+      scope: model
+      lr:
+        name: Piecewise
+        decay_epochs: [30, 60]
+        values: [0.00035, 0.000035, 0.0000035]
+        warmup_epoch: 10
+        warmup_start_lr: 0.0000035
+        warmup_epoch_by_epoch: True
+      regularizer:
+        name: 'L2'
+        coeff: 0.0005
+  - SGD:
+      scope: CenterLoss
+      lr:
+        name: Constant
+        learning_rate: 0.5
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: "VeriWild"
+      image_root: "./dataset/market1501/bounding_box_train"
+      cls_label_path: "./dataset/market1501/bounding_box_train.txt"
+      relabel: True
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            size: [128, 256]
+        - RandFlipImage:
+            flip_code: 1
+        - Pad:
+            padding: 10
+        - RandCropImage:
+            size: [128, 256]
+            scale: [0.8022, 0.8022]
+            ratio: [0.5, 0.5]
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - RandomErasing:
+            EPSILON: 0.5
+            sl: 0.02
+            sh: 0.4
+            r1: 0.3
+            mean: [0.4914, 0.4822, 0.4465]
+    sampler:
+      name: DistributedRandomIdentitySampler
+      batch_size: 64
+      num_instances: 4
+      drop_last: True
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+  Eval:
+    Query:
+      dataset:
+        name: "VeriWild"
+        image_root: "./dataset/market1501/query"
+        cls_label_path: "./dataset/market1501/query.txt"
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: [128, 256]
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ''
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 128
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 4
+        use_shared_memory: True
+
+    Gallery:
+      dataset:
+        name: "VeriWild"
+        image_root: "./dataset/market1501/bounding_box_test"
+        cls_label_path: "./dataset/market1501/bounding_box_test.txt"
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: [128, 256]
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ''
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 128
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 4
+        use_shared_memory: True
+
+Metric:
+  Eval:
+    - Recallk:
+        topk: [1, 5]
+    - mAP: {}
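The three Train losses are reduced to a single training objective as a weighted sum, with the weights given above (1.0, 1.0 and 0.0005). A rough, purely illustrative sketch of that combination (not the actual PaddleClas implementation):

    import paddle

    def combined_loss(loss_outputs, weights):
        # loss_outputs: e.g. [{'CELoss': t1}, {'TripletLossV3': t2}, {'CenterLoss': t3}]
        merged = {}
        total = paddle.zeros([1])
        for out, w in zip(loss_outputs, weights):
            for name, value in out.items():
                merged[name] = value
                total = total + w * value
        merged['loss'] = total
        return merged

    # weights from the config above:
    # combined_loss([ce_out, triplet_out, center_out], [1.0, 1.0, 0.0005])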
@@ -298,12 +298,24 @@ class Engine(object):
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
 
+        if self.config["Global"].get("warmup_epoch_by_epoch", False):
+            for i in range(len(self.lr_sch)):
+                self.lr_sch[i].step()
+            logger.info(
+                "lr_sch step once before first epoch, when Global.warmup_epoch_by_epoch=True"
+            )
+
         for epoch_id in range(best_metric["epoch"] + 1,
                               self.config["Global"]["epochs"] + 1):
             acc = 0.0
             # for one epoch train
             self.train_epoch_func(self, epoch_id, print_batch_step)
+
+            if self.config["Global"].get("warmup_epoch_by_epoch", False):
+                for i in range(len(self.lr_sch)):
+                    self.lr_sch[i].step()
+
             if self.use_dali:
                 self.train_dataloader.reset()
             metric_msg = ", ".join([
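When warmup is counted in epochs, the scheduler value for "epoch 0" is the warmup start LR, so the first epoch would otherwise run entirely at that floor value; stepping the schedulers once up front moves training onto the first warmup increment. A small illustrative sketch using Paddle's built-in epoch-wise warmup (values borrowed from the config above, not from the PaddleClas scheduler itself):

    import paddle

    sched = paddle.optimizer.lr.LinearWarmup(
        learning_rate=0.00035, warmup_steps=10,
        start_lr=0.0000035, end_lr=0.00035)
    print(sched.last_lr)  # 3.5e-06: what the whole first epoch would use without the extra step
    sched.step()          # the "step once before first epoch" from the diff above
    print(sched.last_lr)  # first warmup increment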
@@ -16,6 +16,8 @@ from __future__ import division
 from __future__ import print_function
 
+import platform
+
 import numpy as np
 import paddle
 from ppcls.utils import logger
@@ -49,34 +51,55 @@ def retrieval_eval(engine, epoch_id=0):
         metric_dict = {metric_key: 0.}
     else:
         metric_dict = dict()
-        for block_idx, block_fea in enumerate(fea_blocks):
-            similarity_matrix = paddle.matmul(
-                block_fea, gallery_feas, transpose_y=True)
-            if query_query_id is not None:
-                query_id_block = query_id_blocks[block_idx]
-                query_id_mask = (query_id_block != gallery_unique_id.t())
-
-                image_id_block = image_id_blocks[block_idx]
-                image_id_mask = (image_id_block != gallery_img_id.t())
-
-                keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
-                similarity_matrix = similarity_matrix * keep_mask.astype(
-                    "float32")
-            else:
-                keep_mask = None
-
-            metric_tmp = engine.eval_metric_func(similarity_matrix,
-                                                 image_id_blocks[block_idx],
-                                                 gallery_img_id, keep_mask)
-
-            for key in metric_tmp:
-                if key not in metric_dict:
-                    metric_dict[key] = metric_tmp[key] * block_fea.shape[
-                        0] / len(query_feas)
-                else:
-                    metric_dict[key] += metric_tmp[key] * block_fea.shape[
-                        0] / len(query_feas)
+        reranking_flag = engine.config['Global'].get('re_ranking', False)
+        logger.info(f"re_ranking={reranking_flag}")
+        if not reranking_flag:
+            for block_idx, block_fea in enumerate(fea_blocks):
+                similarity_matrix = paddle.matmul(
+                    block_fea, gallery_feas, transpose_y=True)
+                if query_query_id is not None:
+                    query_id_block = query_id_blocks[block_idx]
+                    query_id_mask = (query_id_block != gallery_unique_id.t())
+
+                    image_id_block = image_id_blocks[block_idx]
+                    image_id_mask = (image_id_block != gallery_img_id.t())
+
+                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
+                    similarity_matrix = similarity_matrix * keep_mask.astype(
+                        "float32")
+                else:
+                    keep_mask = None
+
+                metric_tmp = engine.eval_metric_func(
+                    similarity_matrix, image_id_blocks[block_idx],
+                    gallery_img_id, keep_mask)
+
+                for key in metric_tmp:
+                    if key not in metric_dict:
+                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
+                            0] / len(query_feas)
+                    else:
+                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
+                            0] / len(query_feas)
+        else:
+            distmat = re_ranking(
+                query_feas,
+                gallery_feas,
+                query_img_id,
+                query_query_id,
+                gallery_img_id,
+                gallery_unique_id,
+                k1=20,
+                k2=6,
+                lambda_value=0.3)
+            cmc, mAP = eval_func(distmat,
+                                 np.squeeze(query_img_id.numpy()),
+                                 np.squeeze(gallery_img_id.numpy()),
+                                 np.squeeze(query_query_id.numpy()),
+                                 np.squeeze(gallery_unique_id.numpy()))
+            # fill metric_dict from the re-ranked results (the key names below are
+            # assumed to mirror the Recallk/mAP metrics configured for Eval)
+            metric_dict["recall1"] = cmc[0]
+            metric_dict["recall5"] = cmc[4]
+            metric_dict["mAP"] = mAP
     metric_info_list = []
     for key in metric_dict:
         if metric_key is None:
@@ -88,6 +111,162 @@ def retrieval_eval(engine, epoch_id=0):
     return metric_dict[metric_key]
 
 
+def re_ranking(queFea,
+               galFea,
+               queId=None,
+               queCid=None,
+               galId=None,
+               galCid=None,
+               k1=20,
+               k2=6,
+               lambda_value=0.5,
+               local_distmat=None,
+               only_local=False):
+    # if the feature vectors are numpy arrays, convert them to paddle tensors first
+    query_num = queFea.shape[0]
+    all_num = query_num + galFea.shape[0]
+    if only_local:
+        original_dist = local_distmat
+    else:
+        feat = paddle.concat([queFea, galFea])
+        logger.info('using GPU to compute original distance')
+
+        # L2 distance
+        distmat = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]) + \
+            paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]).t()
+        distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0)
+        # Cosine distance
+        # distmat = paddle.matmul(queFea, galFea, transpose_y=True)
+        # if query_query_id is not None:
+        #     query_id_mask = (queCid != galCid.t())
+        #     image_id_mask = (queId != galId.t())
+        #     keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
+        #     distmat = distmat * keep_mask.astype("float32")
+
+        original_dist = distmat.cpu().numpy()
+        del feat
+    if local_distmat is not None:
+        original_dist = original_dist + local_distmat
+
+    gallery_num = original_dist.shape[0]
+    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
+    V = np.zeros_like(original_dist).astype(np.float16)
+    initial_rank = np.argsort(original_dist).astype(np.int32)
+    logger.info('starting re_ranking')
+    for i in range(all_num):
+        # k-reciprocal neighbors
+        forward_k_neigh_index = initial_rank[i, :k1 + 1]
+        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
+        fi = np.where(backward_k_neigh_index == i)[0]
+        k_reciprocal_index = forward_k_neigh_index[fi]
+        k_reciprocal_expansion_index = k_reciprocal_index
+        for j in range(len(k_reciprocal_index)):
+            candidate = k_reciprocal_index[j]
+            candidate_forward_k_neigh_index = initial_rank[candidate, :int(
+                np.around(k1 / 2)) + 1]
+            candidate_backward_k_neigh_index = initial_rank[
+                candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1]
+            fi_candidate = np.where(
+                candidate_backward_k_neigh_index == candidate)[0]
+            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[
+                fi_candidate]
+            if len(
+                    np.intersect1d(candidate_k_reciprocal_index,
+                                   k_reciprocal_index)) > 2 / 3 * len(
+                                       candidate_k_reciprocal_index):
+                k_reciprocal_expansion_index = np.append(
+                    k_reciprocal_expansion_index, candidate_k_reciprocal_index)
+
+        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
+        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
+        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
+    original_dist = original_dist[:query_num, ]
+    if k2 != 1:
+        V_qe = np.zeros_like(V, dtype=np.float16)
+        for i in range(all_num):
+            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
+        V = V_qe
+        del V_qe
+    del initial_rank
+    invIndex = []
+    for i in range(gallery_num):
+        invIndex.append(np.where(V[:, i] != 0)[0])
+
+    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
+    for i in range(query_num):
+        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
+        indNonZero = np.where(V[i, :] != 0)[0]
+        indImages = [invIndex[ind] for ind in indNonZero]
+        for j in range(len(indNonZero)):
+            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
+                V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
+        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
+
+    final_dist = jaccard_dist * (1 - lambda_value
+                                 ) + original_dist * lambda_value
+    del original_dist
+    del V
+    del jaccard_dist
+    final_dist = final_dist[:query_num, query_num:]
+    return final_dist
+
+
+def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
+    """Evaluation with market1501 metric
+    Key: for each query identity, its gallery images from the same camera view are discarded.
+    """
+    num_q, num_g = distmat.shape
+    if num_g < max_rank:
+        max_rank = num_g
+        print("Note: number of gallery samples is quite small, got {}".format(
+            num_g))
+    indices = np.argsort(distmat, axis=1)
+    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
+
+    # compute cmc curve for each query
+    all_cmc = []
+    all_AP = []
+    num_valid_q = 0.  # number of valid queries
+    for q_idx in range(num_q):
+        # get query pid and camid
+        q_pid = q_pids[q_idx]
+        q_camid = q_camids[q_idx]
+
+        # remove gallery samples that have the same pid and camid as the query
+        order = indices[q_idx]
+        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
+        keep = np.invert(remove)
+
+        # compute cmc curve
+        # binary vector, positions with value 1 are correct matches
+        orig_cmc = matches[q_idx][keep]
+        if not np.any(orig_cmc):
+            # this condition is true when the query identity does not appear in the gallery
+            continue
+
+        cmc = orig_cmc.cumsum()
+        cmc[cmc > 1] = 1
+
+        all_cmc.append(cmc[:max_rank])
+        num_valid_q += 1.
+
+        # compute average precision
+        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
+        num_rel = orig_cmc.sum()
+        tmp_cmc = orig_cmc.cumsum()
+        tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
+        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
+        AP = tmp_cmc.sum() / num_rel
+        all_AP.append(AP)
+
+    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
+
+    all_cmc = np.asarray(all_cmc).astype(np.float32)
+    all_cmc = all_cmc.sum(0) / num_valid_q
+    mAP = np.mean(all_AP)
+
+    return all_cmc, mAP
+
+
 def cal_feature(engine, name='gallery'):
     all_feas = None
     all_image_id = None
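eval_func follows the standard Market-1501 protocol: gallery entries that share both the person id and the camera id with the query are dropped, then CMC and AP are computed per query. A tiny example of the eval_func above on made-up data (shapes and values are purely illustrative):

    import numpy as np

    distmat = np.array([[0.1, 0.9, 0.5],
                        [0.8, 0.2, 0.4]])  # 2 queries x 3 gallery images
    q_pids = np.array([0, 1])
    g_pids = np.array([0, 1, 1])
    q_camids = np.array([0, 0])
    g_camids = np.array([1, 1, 2])         # no query/gallery pair shares pid and camid
    cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=3)
    # cmc[0] is the Rank-1 accuracy, mAP the mean average precision over both queries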
@@ -63,9 +63,27 @@ def train_epoch(engine, epoch_id, print_batch_step):
         loss_dict["loss"].backward()
         for i in range(len(engine.optimizer)):
             engine.optimizer[i].step()
 
+        if hasattr(engine.model.neck, 'bn'):
+            engine.model.neck.bn.bias.grad.set_value(
+                paddle.zeros_like(engine.model.neck.bn.bias.grad))
+
         # clear grad
         for i in range(len(engine.optimizer)):
+            # manually scale up grad of center_loss
+            if i == 1:
+                for j in range(len(engine.train_loss_func.loss_func)):
+                    if len(engine.train_loss_func.loss_func[j].parameters(
+                    )) == 0:
+                        continue
+                    for param in engine.train_loss_func.loss_func[
+                            j].parameters():
+                        if hasattr(param, 'grad') and param.grad is not None:
+                            param.grad.set_value(param.grad * (
+                                1.0 / engine.train_loss_func.loss_weight[j]))
+
             engine.optimizer[i].clear_grad()
 
         # step lr
         for i in range(len(engine.lr_sch)):
             engine.lr_sch[i].step()
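The rescale exists because the combined training loss multiplies CenterLoss by its weight (0.0005 in the config above), so the gradient that reaches the class centers is scaled down by the same factor; dividing by the weight restores the raw center-loss gradient before the dedicated SGD optimizer handles the centers. A small numeric illustration (values made up):

    weight = 0.0005                      # CenterLoss weight from the config
    raw_grad = 0.8                       # pretend d(center_loss)/d(center) for one entry
    weighted_grad = weight * raw_grad    # what backward() on the weighted sum produces
    restored = weighted_grad * (1.0 / weight)
    assert abs(restored - raw_grad) < 1e-12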
@@ -11,7 +11,7 @@ from .emlloss import EmlLoss
 from .msmloss import MSMLoss
 from .npairsloss import NpairsLoss
 from .trihardloss import TriHardLoss
-from .triplet import TripletLoss, TripletLossV2
+from .triplet import TripletLoss, TripletLossV2, TripletLossV3
 from .supconloss import SupConLoss
 from .pairwisecosface import PairwiseCosface
 from .dmlloss import DMLLoss
@ -1,54 +1,74 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class CenterLoss(nn.Layer):
|
||||
def __init__(self, num_classes=5013, feat_dim=2048):
|
||||
"""Center loss class
|
||||
|
||||
Args:
|
||||
num_classes (int): number of classes.
|
||||
feat_dim (int): number of feature dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, feat_dim: int):
|
||||
super(CenterLoss, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.feat_dim = feat_dim
|
||||
self.centers = paddle.randn(
|
||||
shape=[self.num_classes, self.feat_dim]).astype(
|
||||
"float64") #random center
|
||||
random_init_centers = paddle.randn(
|
||||
shape=[self.num_classes, self.feat_dim])
|
||||
self.centers = self.create_parameter(
|
||||
shape=(self.num_classes, self.feat_dim),
|
||||
default_initializer=nn.initializer.Assign(random_init_centers))
|
||||
self.add_parameter("centers", self.centers)
|
||||
|
||||
def __call__(self, input, target):
|
||||
def __call__(self, input: Dict[str, paddle.Tensor],
|
||||
target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
|
||||
"""compute center loss.
|
||||
|
||||
Args:
|
||||
input (Dict[str, paddle.Tensor]): {'features': (batch_size, feature_dim), ...}.
|
||||
target (paddle.Tensor): ground truth label with shape (batch_size, ).
|
||||
|
||||
Returns:
|
||||
Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
|
||||
"""
|
||||
inputs: network output: {"features: xxx", "logits": xxxx}
|
||||
target: image label
|
||||
"""
|
||||
feats = input["features"]
|
||||
feats = input['backbone']
|
||||
labels = target
|
||||
|
||||
# squeeze labels to shape (batch_size, )
|
||||
if labels.ndim >= 2 and labels.shape[-1] == 1:
|
||||
labels = paddle.squeeze(labels, axis=[-1])
|
||||
|
||||
batch_size = feats.shape[0]
|
||||
distmat = paddle.pow(feats, 2).sum(axis=1, keepdim=True).expand([batch_size, self.num_classes]) + \
|
||||
paddle.pow(self.centers, 2).sum(axis=1, keepdim=True).expand([self.num_classes, batch_size]).t()
|
||||
distmat = distmat.addmm(x=feats, y=self.centers.t(), beta=1, alpha=-2)
|
||||
|
||||
#calc feat * feat
|
||||
dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
|
||||
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
|
||||
|
||||
#dist2 of centers
|
||||
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
|
||||
keepdim=True) #num_classes
|
||||
dist2 = paddle.expand(dist2,
|
||||
[self.num_classes, batch_size]).astype("float64")
|
||||
dist2 = paddle.transpose(dist2, [1, 0])
|
||||
|
||||
#first x * x + y * y
|
||||
distmat = paddle.add(dist1, dist2)
|
||||
tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
|
||||
distmat = distmat - 2.0 * tmp
|
||||
|
||||
#generate the mask
|
||||
classes = paddle.arange(self.num_classes).astype("int64")
|
||||
labels = paddle.expand(
|
||||
paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
|
||||
mask = paddle.equal(
|
||||
paddle.expand(classes, [batch_size, self.num_classes]),
|
||||
labels).astype("float64") #get mask
|
||||
|
||||
dist = paddle.multiply(distmat, mask)
|
||||
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
|
||||
classes = paddle.arange(self.num_classes).astype(labels.dtype)
|
||||
labels = labels.unsqueeze(1).expand([batch_size, self.num_classes])
|
||||
mask = labels.equal(classes.expand([batch_size, self.num_classes]))
|
||||
|
||||
dist = distmat * mask.astype(feats.dtype)
|
||||
loss = dist.clip(min=1e-12, max=1e+12).sum() / batch_size
|
||||
# return loss
|
||||
return {'CenterLoss': loss}
|
||||
|
|
|
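A minimal usage sketch for CenterLoss (the dict input and shapes follow the code above; the import path and concrete numbers are assumptions for illustration):

    import paddle
    from ppcls.loss import CenterLoss  # assumed import path

    loss_fn = CenterLoss(num_classes=751, feat_dim=2048)
    feats = paddle.randn([64, 2048])           # backbone features for one batch
    labels = paddle.randint(0, 751, [64, 1])   # integer person ids, squeezed internally
    out = loss_fn({'backbone': feats}, labels)
    print(out['CenterLoss'])                   # scalar loss; `centers` is a trainable parameter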
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from typing import Tuple
 
 import paddle
 import paddle.nn as nn
@@ -135,3 +136,122 @@ class TripletLoss(nn.Layer):
         y = paddle.ones_like(dist_an)
         loss = self.ranking_loss(dist_an, dist_ap, y)
         return {"TripletLoss": loss}
+
+
+class TripletLossV3(nn.Layer):
+    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
+    Related Triplet Loss theory can be found in the paper 'In Defense of the Triplet
+    Loss for Person Re-Identification'."""
+
+    def __init__(self, margin=None, normalize_feature=False):
+        super(TripletLossV3, self).__init__()
+        self.normalize_feature = normalize_feature
+        self.margin = margin
+        if margin is not None:
+            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
+        else:
+            self.ranking_loss = nn.SoftMarginLoss()
+
+    def forward(self, input, target):
+        global_feat = input["backbone"]
+        if self.normalize_feature:
+            global_feat = self._normalize(global_feat, axis=-1)
+        dist_mat = self._euclidean_dist(global_feat, global_feat)
+        dist_ap, dist_an = self._hard_example_mining(dist_mat, target)
+        y = paddle.ones_like(dist_an)
+        if self.margin is not None:
+            loss = self.ranking_loss(dist_an, dist_ap, y)
+        else:
+            loss = self.ranking_loss(dist_an - dist_ap, y)
+
+        return {"TripletLossV3": loss}
+
+    def _normalize(self, x: paddle.Tensor, axis: int=-1) -> paddle.Tensor:
+        """Normalize to unit length along the specified dimension.
+
+        Args:
+            x (paddle.Tensor): (batch_size, feature_dim)
+            axis (int, optional): normalization dim. Defaults to -1.
+
+        Returns:
+            paddle.Tensor: (batch_size, feature_dim)
+        """
+        x = 1. * x / (paddle.norm(
+            x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
+        return x
+
+    def _euclidean_dist(self, x: paddle.Tensor,
+                        y: paddle.Tensor) -> paddle.Tensor:
+        """compute euclidean distance between two batches of vectors
+
+        Args:
+            x (paddle.Tensor): (N, feature_dim)
+            y (paddle.Tensor): (M, feature_dim)
+
+        Returns:
+            paddle.Tensor: (N, M)
+        """
+        m, n = x.shape[0], y.shape[0]
+        d = x.shape[1]
+        xx = paddle.pow(x, 2).sum(1, keepdim=True).expand([m, n])
+        yy = paddle.pow(y, 2).sum(1, keepdim=True).expand([n, m]).t()
+        dist = xx + yy
+        dist = dist.addmm(x, y.t(), alpha=-2, beta=1)
+        # dist = dist - 2*(x@y.t())
+        dist = dist.clip(min=1e-12).sqrt()  # for numerical stability
+        return dist
+
+    def _hard_example_mining(
+            self,
+            dist_mat: paddle.Tensor,
+            labels: paddle.Tensor,
+            return_inds: bool=False) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """For each anchor, find the hardest positive and negative sample.
+
+        Args:
+            dist_mat (paddle.Tensor): pairwise distance between samples, [N, N]
+            labels (paddle.Tensor): labels, [N, ]
+            return_inds (bool, optional): whether to return the indices. Defaults to False.
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor]: [(N, ), (N, )]
+
+        NOTE: Only consider the case in which all labels have the same number of samples,
+        thus we can cope with all anchors in parallel.
+        """
+        assert len(dist_mat.shape) == 2
+        assert dist_mat.shape[0] == dist_mat.shape[1]
+        N = dist_mat.shape[0]
+
+        # shape [N, N]
+        is_pos = labels.expand([N, N]).equal(labels.expand([N, N]).t())
+        is_neg = labels.expand([N, N]).not_equal(labels.expand([N, N]).t())
+
+        # `dist_ap` means distance(anchor, positive)
+        # both `dist_ap` and `relative_p_inds` with shape [N, 1]
+        dist_ap = paddle.max(dist_mat[is_pos].reshape([N, -1]),
+                             1,
+                             keepdim=True)
+        # `dist_an` means distance(anchor, negative)
+        # both `dist_an` and `relative_n_inds` with shape [N, 1]
+        dist_an = paddle.min(dist_mat[is_neg].reshape([N, -1]),
+                             1,
+                             keepdim=True)
+        # shape [N]
+        dist_ap = dist_ap.squeeze(1)
+        dist_an = dist_an.squeeze(1)
+
+        if return_inds:
+            # NOTE: this branch is never taken (return_inds is always False above); it still
+            # uses torch-style tensor APIs and undefined `relative_p_inds`/`relative_n_inds`.
+            # shape [N, N]
+            ind = (labels.new().resize_as_(labels)
+                   .copy_(paddle.arange(0, N).long())
+                   .unsqueeze(0).expand(N, N))
+            # shape [N, 1]
+            p_inds = paddle.gather(ind[is_pos].reshape([N, -1]), 1,
+                                   relative_p_inds.data)
+            n_inds = paddle.gather(ind[is_neg].reshape([N, -1]), 1,
+                                   relative_n_inds.data)
+            # shape [N]
+            p_inds = p_inds.squeeze(1)
+            n_inds = n_inds.squeeze(1)
+            return dist_ap, dist_an, p_inds, n_inds
+
+        return dist_ap, dist_an
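A hedged usage sketch for TripletLossV3; the batch layout follows the sampler in the config (PK sampling, 4 instances per identity), and the surrounding names are assumptions for illustration:

    import paddle

    loss_fn = TripletLossV3(margin=0.3, normalize_feature=False)
    feats = paddle.randn([16, 2048])                        # 4 identities x 4 instances
    labels = paddle.to_tensor([i // 4 for i in range(16)], dtype='int64')
    out = loss_fn({'backbone': feats}, labels)
    print(out['TripletLossV3'])                             # scalar hard-mining triplet loss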
@@ -103,8 +103,11 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
         if optim_scope.endswith("Loss"):
             # optimizer for loss
             for m in model_list[i].sublayers(True):
-                if m.__class_name == optim_scope:
+                if m.__class__.__name__ == optim_scope:
                     optim_model.append(m)
         elif optim_scope == "model":
             # optimizer for the entire model
             optim_model.append(model_list[i])
         else:
             # optimizer for a module in the model, such as backbone, neck, head...
             if hasattr(model_list[i], optim_scope):
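The scope matching above searches a loss container's sublayers by class name, so `scope: CenterLoss` in the config picks up the CenterLoss layer (and its `centers` parameter) for the second optimizer. A small illustrative sketch of that lookup, not the PaddleClas source:

    import paddle.nn as nn

    class DummyCenterLoss(nn.Layer):
        def __init__(self):
            super().__init__()
            self.centers = self.create_parameter(shape=[751, 2048])

    container = nn.Sequential(DummyCenterLoss())
    picked = [m for m in container.sublayers(True)
              if m.__class__.__name__ == 'DummyCenterLoss']
    assert len(picked) == 1 and len(picked[0].parameters()) == 1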