refine code
parent 97e8abc3db · commit 674447f63e

@@ -48,6 +48,7 @@ Loss:
         weight: 1.0
         margin: 0.3
         normalize_feature: false
+        feat_from: "backbone"
   Eval:
     - CELoss:
         weight: 1.0

@@ -61,6 +61,7 @@ Loss:
         weight: 1.0
         margin: 0.3
         normalize_feature: false
+        feat_from: "backbone"
   Eval:
     - CELoss:
         weight: 1.0

@@ -40,7 +40,7 @@ Arch:
       initializer:
         name: Constant
         value: 0.0
-      learning_rate: 1.0e-20 # TODO: Temporarily set lr small enough to freeze the bias
+      learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias
   Head:
     name: "FC"
     embedding_size: *feat_dim

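Side note on the hunk above: in Paddle a per-parameter learning-rate multiplier can be set through ParamAttr, and a value as small as 1.0e-20 effectively freezes the parameter, which is what the "freeze the bias" comment relies on. A minimal sketch, assuming the YAML bias_attr block maps to a ParamAttr (layer sizes below are placeholders):

    import paddle

    # Hypothetical equivalent of the config above: constant-zero init plus a
    # near-zero learning-rate multiplier, so the bias never moves during training.
    bias_attr = paddle.ParamAttr(
        initializer=paddle.nn.initializer.Constant(value=0.0),
        learning_rate=1.0e-20)
    fc = paddle.nn.Linear(512, 751, bias_attr=bias_attr)  # placeholder sizes
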
@@ -57,14 +57,16 @@ Loss:
     - CELoss:
         weight: 1.0
         epsilon: 0.1
-    - TripletLossV3:
+    - TripletLossV2:
         weight: 1.0
         margin: 0.3
         normalize_feature: false
+        feat_from: "backbone"
     - CenterLoss:
         weight: 0.0005
         num_classes: *class_num
         feat_dim: *feat_dim
+        feat_from: "backbone"
   Eval:
     - CELoss:
         weight: 1.0

@@ -28,12 +28,17 @@ class CenterLoss(nn.Layer):
     Args:
         num_classes (int): number of classes.
         feat_dim (int): number of feature dimensions.
+        feat_from (str): features from backbone or neck
     """

-    def __init__(self, num_classes: int, feat_dim: int):
+    def __init__(self,
+                 num_classes: int,
+                 feat_dim: int,
+                 feat_from: str='backbone'):
         super(CenterLoss, self).__init__()
         self.num_classes = num_classes
         self.feat_dim = feat_dim
+        self.feat_from = feat_from
         random_init_centers = paddle.randn(
             shape=[self.num_classes, self.feat_dim])
         self.centers = self.create_parameter(

@@ -52,7 +57,7 @@ class CenterLoss(nn.Layer):
         Returns:
             Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
         """
-        feats = input['backbone']
+        feats = input[self.feat_from]
         labels = target

         # squeeze labels to shape (batch_size, )

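A minimal sketch (not part of this commit) of what the new feat_from switch selects: the model output is assumed to be a dict of named feature tensors, and the loss now indexes it with self.feat_from instead of the hard-coded 'backbone' key. Names and shapes are illustrative only:

    import paddle

    features = {
        "backbone": paddle.randn([8, 2048]),  # hypothetical raw backbone features
        "neck": paddle.randn([8, 512]),       # hypothetical BN-neck features
    }
    feat_from = "neck"                        # value coming from the YAML option
    feats = features[feat_from]               # mirrors feats = input[self.feat_from]
    print(feats.shape)                        # [8, 512]
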
@@ -1,7 +1,20 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from typing import Tuple

 import paddle
 import paddle.nn as nn

@@ -13,9 +26,13 @@ class TripletLossV2(nn.Layer):
         margin (float): margin for triplet.
     """

-    def __init__(self, margin=0.5, normalize_feature=True):
+    def __init__(self,
+                 margin=0.5,
+                 normalize_feature=True,
+                 feat_from='backbone'):
         super(TripletLossV2, self).__init__()
         self.margin = margin
+        self.feat_from = feat_from
         self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
         self.normalize_feature = normalize_feature


@@ -25,7 +42,7 @@ class TripletLossV2(nn.Layer):
             inputs: feature matrix with shape (batch_size, feat_dim)
             target: ground truth labels with shape (num_classes)
         """
-        inputs = input["backbone"]
+        inputs = input[self.feat_from]

         if self.normalize_feature:
             inputs = 1. * inputs / (paddle.expand_as(

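For reference, a short sketch of what normalize_feature=true would do before the distance computation, written against the public Paddle API rather than the code above: every row is rescaled to unit L2 norm, with a small epsilon for stability.

    import paddle

    x = paddle.randn([8, 512])                         # hypothetical feature batch
    norm = paddle.sqrt((x * x).sum(axis=1, keepdim=True))
    x_unit = x / (norm + 1e-12)                        # each row now has unit L2 norm
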
@@ -136,122 +153,3 @@ class TripletLoss(nn.Layer):
         y = paddle.ones_like(dist_an)
         loss = self.ranking_loss(dist_an, dist_ap, y)
         return {"TripletLoss": loss}
-
-
-class TripletLossV3(nn.Layer):
-    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
-    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
-    Loss for Person Re-Identification'."""
-
-    def __init__(self, margin=None, normalize_feature=False):
-        super(TripletLossV3, self).__init__()
-        self.normalize_feature = normalize_feature
-        self.margin = margin
-        if margin is not None:
-            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
-        else:
-            self.ranking_loss = nn.SoftMarginLoss()
-
-    def forward(self, input, target):
-        global_feat = input["backbone"]
-        if self.normalize_feature:
-            global_feat = self._normalize(global_feat, axis=-1)
-        dist_mat = self._euclidean_dist(global_feat, global_feat)
-        dist_ap, dist_an = self._hard_example_mining(dist_mat, target)
-        y = paddle.ones_like(dist_an)
-        if self.margin is not None:
-            loss = self.ranking_loss(dist_an, dist_ap, y)
-
-        return {"TripletLossV3": loss}
-
-    def _normalize(self, x: paddle.Tensor, axis: int=-1) -> paddle.Tensor:
-        """Normalizing to unit length along the specified dimension.
-
-        Args:
-            x (paddle.Tensor): (batch_size, feature_dim)
-            axis (int, optional): normalization dim. Defaults to -1.
-
-        Returns:
-            paddle.Tensor: (batch_size, feature_dim)
-        """
-        x = 1. * x / (paddle.norm(
-            x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
-        return x
-
-    def _euclidean_dist(self, x: paddle.Tensor,
-                        y: paddle.Tensor) -> paddle.Tensor:
-        """compute euclidean distance between two batched vectors
-
-        Args:
-            x (paddle.Tensor): (N, feature_dim)
-            y (paddle.Tensor): (M, feature_dim)
-
-        Returns:
-            paddle.Tensor: (N, M)
-        """
-        m, n = x.shape[0], y.shape[0]
-        d = x.shape[1]
-        xx = paddle.pow(x, 2).sum(1, keepdim=True).expand([m, n])
-        yy = paddle.pow(y, 2).sum(1, keepdim=True).expand([n, m]).t()
-        dist = xx + yy
-        dist = dist.addmm(x, y.t(), alpha=-2, beta=1)
-        # dist = dist - 2*(x@y.t())
-        dist = dist.clip(min=1e-12).sqrt()  # for numerical stability
-        return dist
-
-    def _hard_example_mining(
-            self,
-            dist_mat: paddle.Tensor,
-            labels: paddle.Tensor,
-            return_inds: bool=False) -> Tuple[paddle.Tensor, paddle.Tensor]:
-        """For each anchor, find the hardest positive and negative sample.
-
-        Args:
-            dist_mat (paddle.Tensor): pair wise distance between samples, [N, N]
-            labels (paddle.Tensor): labels, [N, ]
-            return_inds (bool, optional): whether to return the indices . Defaults to False.
-
-        Returns:
-            Tuple[paddle.Tensor, paddle.Tensor]: [(N, ), (N, )]
-
-        NOTE: Only consider the case in which all labels have same num of samples,
-            thus we can cope with all anchors in parallel.
-        """
-        assert len(dist_mat.shape) == 2
-        assert dist_mat.shape[0] == dist_mat.shape[1]
-        N = dist_mat.shape[0]
-
-        # shape [N, N]
-        is_pos = labels.expand([N, N]).equal(labels.expand([N, N]).t())
-        is_neg = labels.expand([N, N]).not_equal(labels.expand([N, N]).t())
-
-        # `dist_ap` means distance(anchor, positive)
-        # both `dist_ap` and `relative_p_inds` with shape [N, 1]
-        dist_ap = paddle.max(dist_mat[is_pos].reshape([N, -1]),
-                             1,
-                             keepdim=True)
-        # `dist_an` means distance(anchor, negative)
-        # both `dist_an` and `relative_n_inds` with shape [N, 1]
-        dist_an = paddle.min(dist_mat[is_neg].reshape([N, -1]),
-                             1,
-                             keepdim=True)
-        # shape [N]
-        dist_ap = dist_ap.squeeze(1)
-        dist_an = dist_an.squeeze(1)
-
-        if return_inds:
-            # shape [N, N]
-            ind = (labels.new().resize_as_(labels)
-                   .copy_(paddle.arange(0, N).long())
-                   .unsqueeze(0).expand(N, N))
-            # shape [N, 1]
-            p_inds = paddle.gather(ind[is_pos].reshape([N, -1]), 1,
-                                   relative_p_inds.data)
-            n_inds = paddle.gather(ind[is_neg].reshape([N, -1]), 1,
-                                   relative_n_inds.data)
-            # shape [N]
-            p_inds = p_inds.squeeze(1)
-            n_inds = n_inds.squeeze(1)
-            return dist_ap, dist_an, p_inds, n_inds
-
-        return dist_ap, dist_an

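The removed _euclidean_dist relied on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, which the addmm call implemented. A self-contained sketch of the same pairwise computation using broadcasting (shapes are placeholders; this is not the deleted helper itself):

    import paddle

    x = paddle.randn([6, 128])                        # hypothetical anchors
    y = paddle.randn([4, 128])                        # hypothetical candidates
    xx = (x * x).sum(axis=1, keepdim=True)            # [6, 1]
    yy = (y * y).sum(axis=1, keepdim=True).t()        # [1, 4]
    dist_sq = xx + yy - 2.0 * paddle.matmul(x, y, transpose_y=True)
    dist = paddle.clip(dist_sq, min=1e-12).sqrt()     # [6, 4] pairwise distances
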
@@ -46,7 +46,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     optim_config = copy.deepcopy(config)
     if isinstance(optim_config, dict):
-        # convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
+        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
         optim_config: List[Dict[str, Dict]] = [{
             optim_name: {

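A quick sketch of the conversion described by the rewritten comment, with placeholder values: a single optimizer dict from the YAML is wrapped into a one-element list keyed by the optimizer name, so the rest of build_optimizer can always iterate over a list of scoped optimizer configs.

    # Hypothetical single-optimizer config as parsed from the YAML.
    optim_config = {
        "name": "Momentum",
        "momentum": 0.9,
        "lr": {"name": "Cosine", "learning_rate": 0.04},
        "scope": "all",
    }
    optim_name = optim_config.pop("name")          # "Momentum"
    optim_config = [{optim_name: optim_config}]    # [{'Momentum': {'scope': 'all', ...}}]
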
@@ -60,20 +60,19 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     """NOTE:
     Currently only support optim objets below.
     1. single optimizer config.
-    2. model(entire Arch), backbone, neck, head.
-    3. loss(entire Loss), specific loss listed in ppcls/loss/__init__.py.
+    2. next level uner Arch, such as Arch.backbone, Arch.neck, Arch.head.
+    3. loss which has parameters, such as CenterLoss.
     """
     for optim_item in optim_config:
-        # optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
+        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
         # step1 build lr
         optim_name = list(optim_item.keys())[0]  # get optim_name
-        optim_scope_list = optim_item[optim_name].pop('scope').split(
-            ' ')  # get optim_scope list
+        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
         optim_cfg = optim_item[optim_name]  # get optim_cfg

         lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
-        logger.info("build lr ({}) for scope ({}) success..".format(
-            lr.__class__.__name__, optim_scope_list))
+        logger.debug("build lr ({}) for scope ({}) success..".format(
+            lr, optim_scope))
         # step2 build regularization
         if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
             if 'weight_decay' in optim_cfg:

@@ -84,14 +83,12 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             reg_name = reg_config.pop('name') + 'Decay'
             reg = getattr(paddle.regularizer, reg_name)(**reg_config)
             optim_cfg["weight_decay"] = reg
-            logger.info("build regularizer ({}) for scope ({}) success..".
-                        format(reg.__class__.__name__, optim_scope_list))
+            logger.debug("build regularizer ({}) for scope ({}) success..".
+                         format(reg, optim_scope))
         # step3 build optimizer
         if 'clip_norm' in optim_cfg:
             clip_norm = optim_cfg.pop('clip_norm')
             grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
-            logger.info("build gradclip ({}) for scope ({}) success..".format(
-                grad_clip.__class__.__name__, optim_scope_list))
         else:
             grad_clip = None
         optim_model = []

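As an aside, a hedged sketch of the two lookups in this hunk (coefficients are placeholders): the regularizer name from the config plus the 'Decay' suffix resolves to a class in paddle.regularizer, and clip_norm becomes a ClipGradByNorm object that is later handed to the optimizer as grad_clip.

    import paddle

    reg_config = {"name": "L2", "coeff": 0.0005}   # hypothetical regularizer block
    reg_name = reg_config.pop("name") + "Decay"    # -> "L2Decay"
    weight_decay = getattr(paddle.regularizer, reg_name)(**reg_config)
    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10.0)
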
@@ -104,34 +101,30 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             return optim, lr

         # for dynamic graph
-        for scope in optim_scope_list:
-            if scope == "all":
-                optim_model += model_list
-            elif scope == "model":
-                optim_model += [model_list[0], ]
-            elif scope in ["backbone", "neck", "head"]:
-                optim_model += [getattr(model_list[0], scope, None), ]
-            elif scope == "loss":
-                optim_model += [model_list[1], ]
-            else:
-                optim_model += [
-                    model_list[1].loss_func[i]
-                    for i in range(len(model_list[1].loss_func))
-                    if model_list[1].loss_func[i].__class__.__name__ == scope
-                ]
-        # remove invalid items
-        optim_model = [
-            optim_model[i] for i in range(len(optim_model))
-            if (optim_model[i] is not None
-                ) and (len(optim_model[i].parameters()) > 0)
-        ]
-        assert len(optim_model) > 0, \
-            f"optim_model is empty for optim_scope({optim_scope_list})"
+        for i in range(len(model_list)):
+            if len(model_list[i].parameters()) == 0:
+                continue
+            if optim_scope == "all":
+                # optimizer for all
+                optim_model.append(model_list[i])
+            else:
+                if optim_scope.endswith("Loss"):
+                    # optimizer for loss
+                    for m in model_list[i].sublayers(True):
+                        if m.__class__.__name__ == optim_scope:
+                            optim_model.append(m)
+                else:
+                    # opmizer for module in model, such as backbone, neck, head...
+                    if hasattr(model_list[i], optim_scope):
+                        optim_model.append(getattr(model_list[i], optim_scope))
+
+        assert len(optim_model) == 1, \
+            "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
         optim = getattr(optimizer, optim_name)(
             learning_rate=lr, grad_clip=grad_clip,
             **optim_cfg)(model_list=optim_model)
-        logger.info("build optimizer ({}) for scope ({}) success..".format(
-            optim.__class__.__name__, optim_scope_list))
+        logger.debug("build optimizer ({}) for scope ({}) success..".format(
+            optim, optim_scope))
         optim_list.append(optim)
         lr_list.append(lr)
     return optim_list, lr_list

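To summarize the new matching rule, a self-contained sketch (the toy model below is invented for illustration, not PaddleClas code): scope "all" takes every member of model_list that owns parameters, a scope ending in "Loss" collects sublayers of that class name, and any other scope is looked up as an attribute such as backbone or head.

    import paddle.nn as nn

    class ToyArch(nn.Layer):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 8)
            self.head = nn.Linear(8, 4)

    class CenterLoss(nn.Layer):            # stand-in for a loss that owns parameters
        def __init__(self):
            super().__init__()
            self.centers = self.create_parameter(shape=[4, 8])

    model_list = [ToyArch(), nn.LayerList([CenterLoss()])]
    optim_scope = "CenterLoss"             # could also be "all", "backbone", "head", ...

    optim_model = []
    for layer in model_list:
        if len(layer.parameters()) == 0:
            continue
        if optim_scope == "all":
            optim_model.append(layer)
        elif optim_scope.endswith("Loss"):
            optim_model += [m for m in layer.sublayers(True)
                            if m.__class__.__name__ == optim_scope]
        elif hasattr(layer, optim_scope):
            optim_model.append(getattr(layer, optim_scope))

    print([m.__class__.__name__ for m in optim_model])   # ['CenterLoss']
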