mirror of https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00

update freeze layer

update preciseBN
update circle loss with metric learning and cross entropy loss form
update loss call methods

This commit is contained in:
parent 6a8961ce48
commit 4d2fa28dbb
@@ -173,6 +173,11 @@ _C.TEST = CN()
 _C.TEST.EVAL_PERIOD = 50
 _C.TEST.IMS_PER_BATCH = 128

+# Precise BN
+_C.TEST.PRECISE_BN = CN()
+_C.TEST.PRECISE_BN.ENABLED = False
+_C.TEST.PRECISE_BN.DATASET = 'Market1501'
+_C.TEST.PRECISE_BN.NUM_ITER = 300
 # ---------------------------------------------------------------------------- #
 # Misc options
 # ---------------------------------------------------------------------------- #
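For context, a minimal sketch of what a precise-BN pass over these settings amounts to: re-estimating the BatchNorm running statistics from NUM_ITER training batches. This is an illustrative stand-in, not fast-reid's hooks.PreciseBN, and the (images, targets) loader format is an assumption:

import torch
import torch.nn as nn

@torch.no_grad()
def recompute_bn_stats(model: nn.Module, loader, num_iter: int = 300):
    # Collect every BatchNorm layer in the model.
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    if not bn_layers:
        return
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None  # None = cumulative moving average in PyTorch
    was_training = model.training
    model.train()  # BN updates running stats only in train mode
    for i, (images, _) in enumerate(loader):  # assumed (images, targets) batches
        if i >= num_iter:
            break
        model(images)
    model.train(was_training)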
@@ -257,17 +257,21 @@ class DefaultTrainer(SimpleTrainer):
         ret = [
             hooks.IterationTimer(),
             hooks.LRScheduler(self.optimizer, self.scheduler),
-            # hooks.PreciseBN(
-            #     # Run at the same freq as (but before) evaluation.
-            #     cfg.TEST.EVAL_PERIOD,
-            #     self.model,
-            #     # Build a new data loader to not affect training
-            #     self.build_train_loader(cfg),
-            #     cfg.TEST.PRECISE_BN.NUM_ITER,
-            # )
-            # if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
-            # else None,
-            hooks.FreezeLayer(self.model, cfg.MODEL.OPEN_LAYERS, cfg.SOLVER.FREEZE_ITERS)
+            hooks.PreciseBN(
+                # Run at the same freq as (but before) evaluation.
+                cfg.TEST.EVAL_PERIOD,
+                self.model,
+                # Build a new data loader to not affect training
+                self.build_train_loader(cfg),
+                cfg.TEST.PRECISE_BN.NUM_ITER,
+            )
+            if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model)
+            else None,
+            hooks.FreezeLayer(
+                self.model,
+                cfg.MODEL.OPEN_LAYERS,
+                cfg.SOLVER.FREEZE_ITERS)
+            if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0 else None,
         ]

         # Do PreciseBN before checkpointer, because it updates the model and need to
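The two optional hooks above rely on a pattern where a disabled feature contributes a None entry that the trainer later drops (fast-reid's engine is modeled on detectron2, whose register_hooks filters out None). A self-contained sketch with stand-in classes, not fast-reid's hook types:

class Hook:
    """Simplified stand-in for a trainer hook base class."""

def build_hooks(precise_bn_enabled: bool, freeze_iters: int):
    ret = [
        Hook(),                                  # always-on hook, e.g. a timer
        Hook() if precise_bn_enabled else None,  # optional PreciseBN slot
        Hook() if freeze_iters > 0 else None,    # optional FreezeLayer slot
    ]
    # Drop the disabled entries before registering.
    return [h for h in ret if h is not None]

print(len(build_hooks(False, 0)))  # -> 1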
@@ -430,9 +430,7 @@ class FreezeLayer(HookBase):
         self.model = model

         self.freeze_iters = freeze_iters
-        if open_layer_names == '':
-            self.freeze_iters = -1
-        elif isinstance(open_layer_names, str):
+        if isinstance(open_layer_names, str):
             open_layer_names = [open_layer_names]

         self.open_layer_names = open_layer_names
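The rewritten constructor accepts either one layer name or a list of names and always stores a list; the old empty-string branch is gone because the trainer hunk above now creates the FreezeLayer hook only when cfg.MODEL.OPEN_LAYERS != ''. A tiny sketch of the normalization (the layer names are illustrative):

def normalize_open_layers(open_layer_names):
    # Accept 'heads' or ['backbone', 'heads']; always return a list.
    if isinstance(open_layer_names, str):
        open_layer_names = [open_layer_names]
    return open_layer_names

assert normalize_open_layers('heads') == ['heads']
assert normalize_open_layers(['backbone', 'heads']) == ['backbone', 'heads']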
@@ -38,7 +38,8 @@ class BNneckHead(nn.Module):
             return bn_feat
         # training
         pred_class_logits = self.classifier(bn_feat)
-        return pred_class_logits, global_feat
+        # return pred_class_logits, global_feat
+        return pred_class_logits, bn_feat

     @classmethod
     def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
@@ -48,20 +48,19 @@ class CircleHead(nn.Module):
         if not self.training:
             return bn_feat

-        cos_sim = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
-        alpha_p = F.relu(-cos_sim + 1 + self._m)
-        alpha_n = F.relu(cos_sim + self._m)
-        margin_p = 1 - self._m
-        margin_n = self._m
+        sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
+        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
+        alpha_n = F.relu(sim_mat.detach() + self._m)
+        delta_p = 1 - self._m
+        delta_n = self._m

-        sp = alpha_p * (cos_sim - margin_p)
-        sn = alpha_n * (cos_sim - margin_n)
+        s_p = self._s * alpha_p * (sim_mat - delta_p)
+        s_n = self._s * alpha_n * (sim_mat - delta_n)

-        one_hot = torch.zeros(cos_sim.size()).to(targets.device)
+        one_hot = torch.zeros(sim_mat.size()).to(targets.device)
         one_hot.scatter_(1, targets.view(-1, 1).long(), 1)

-        pred_class_logits = one_hot * sp + ((1.0 - one_hot) * sn)
-        pred_class_logits *= self._s
+        pred_class_logits = one_hot * s_p + (1.0 - one_hot) * s_n

         return pred_class_logits, global_feat
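Two behavioral notes on this rewrite: alpha_p and alpha_n are now computed from a detached similarity matrix, so gradients flow only through the (sim_mat - delta) factors, matching the Circle Loss formulation, and the scale self._s is folded directly into s_p and s_n instead of being applied to the combined logits afterwards. The folding is algebraically a no-op, which this short check (a sketch, not fast-reid code) confirms:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, m = 64.0, 0.25
sim = torch.rand(4, 10) * 2 - 1          # fake cosine similarities in [-1, 1]
alpha_p = F.relu(-sim.detach() + 1 + m)  # clamped, gradient-free weights

old = (alpha_p * (sim - (1 - m))) * s    # scale applied after combining
new = s * alpha_p * (sim - (1 - m))      # scale folded into each term
assert torch.allclose(old, new)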
@@ -7,9 +7,9 @@
 from torch import nn

 from .build import REID_HEADS_REGISTRY
-from ..losses import CrossEntropyLoss, TripletLoss
-from ..model_utils import weights_init_classifier, weights_init_kaiming
-from ...layers import bn_no_bias, Flatten
+from .. import losses as Loss
+from ..model_utils import weights_init_classifier
+from ...layers import Flatten


 @REID_HEADS_REGISTRY.register()
@@ -41,11 +41,8 @@ class LinearHead(nn.Module):
     @classmethod
     def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
         loss_dict = {}
-        if "CrossEntropyLoss" in cfg.MODEL.LOSSES.NAME and pred_class_logits is not None:
-            loss = CrossEntropyLoss(cfg)(pred_class_logits, gt_classes)
-            loss_dict.update(loss)
-        if "TripletLoss" in cfg.MODEL.LOSSES.NAME and global_features is not None:
-            loss = TripletLoss(cfg)(global_features, gt_classes)
+        for loss_name in cfg.MODEL.LOSSES.NAME:
+            loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
             loss_dict.update(loss)
         # rename
         name_loss_dict = {}
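The losses() rewrite dispatches by name: each entry of cfg.MODEL.LOSSES.NAME is looked up on the losses package with getattr, and every loss object is called with the same (pred_class_logits, global_features, gt_classes) signature, which is why the __call__ methods in the hunks below take a `_` placeholder for the argument they ignore. A self-contained sketch of the pattern, with toy losses and a dict standing in for the getattr lookup:

import torch
import torch.nn.functional as F

class ToyCrossEntropyLoss:
    def __call__(self, pred_class_logits, _, gt_classes):
        return {"loss_cls": F.cross_entropy(pred_class_logits, gt_classes)}

class ToyTripletLoss:
    def __call__(self, _, global_features, gt_classes):
        # stand-in body; the real mining logic lives in metric_loss.py
        return {"loss_triplet": global_features.norm(dim=1).mean()}

LOSSES = {"CrossEntropyLoss": ToyCrossEntropyLoss, "TripletLoss": ToyTripletLoss}

logits, feats = torch.randn(8, 10), torch.randn(8, 128)
targets = torch.randint(0, 10, (8,))
loss_dict = {}
for loss_name in ("CrossEntropyLoss", "TripletLoss"):  # mirrors cfg.MODEL.LOSSES.NAME
    loss_dict.update(LOSSES[loss_name]()(logits, feats, targets))
print(loss_dict)  # {'loss_cls': ..., 'loss_triplet': ...}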
@@ -5,4 +5,4 @@
 """

 from .cross_entroy_loss import CrossEntropyLoss
-from .metric_loss import TripletLoss
+from .metric_loss import *
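The bare star import stays bounded because metric_loss.py now defines __all__ = ["TripletLoss", "CircleLoss"] (see the later hunk), so only the two loss classes are re-exported and module-level helpers such as normalize() stay out of the package namespace.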
@@ -40,7 +40,7 @@ class CrossEntropyLoss(object):
         storage = get_event_storage()
         storage.put_scalar("cls_accuracy", ret[0])

-    def __call__(self, pred_class_logits, gt_classes):
+    def __call__(self, pred_class_logits, _, gt_classes):
         """
         Compute the softmax cross entropy loss for box classification.
         Returns:
@@ -6,8 +6,11 @@

 import torch
 from torch import nn
+import torch.nn.functional as F


+__all__ = ["TripletLoss", "CircleLoss"]
+
 def normalize(x, axis=-1):
     """Normalizing to unit length along the specified dimension.
     Args:
@@ -101,8 +104,10 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
     assert len(dist_mat.size()) == 2
     assert dist_mat.size(0) == dist_mat.size(1)

-    dist_ap = dist_mat * is_pos.float()
-    dist_an = dist_mat * is_neg.float()
+    is_pos = is_pos.float()
+    is_neg = is_neg.float()
+    dist_ap = dist_mat * is_pos
+    dist_an = dist_mat * is_neg

     weights_ap = softmax_weights(dist_ap, is_pos)
     weights_an = softmax_weights(-dist_an, is_neg)
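After the masks are cast to float, softmax_weights turns each masked distance row into weights that sum to one over the valid entries, which is why the float masks are reused as both multiplier and mask. A hypothetical stand-in for that helper (fast-reid's actual implementation may differ in details):

import torch

def masked_softmax(dist, mask):
    # Row-wise softmax restricted to entries where mask == 1,
    # shifted by the row max for numerical stability.
    max_v = (dist * mask).max(dim=1, keepdim=True)[0]
    diff = dist - max_v
    z = (torch.exp(diff) * mask).sum(dim=1, keepdim=True) + 1e-6
    return torch.exp(diff) * mask / z

dist = torch.rand(4, 4)
is_pos = torch.eye(4)  # toy {0, 1} float mask
w = masked_softmax(dist * is_pos, is_pos)
assert torch.allclose(w.sum(dim=1), torch.ones(4), atol=1e-4)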
@@ -130,7 +135,7 @@ class TripletLoss(object):
         else:
             self.ranking_loss = nn.SoftMarginLoss()

-    def __call__(self, global_features, targets):
+    def __call__(self, _, global_features, targets):
         if self._normalize_feature:
             global_features = normalize(global_features, axis=-1)

@@ -157,3 +162,38 @@ class TripletLoss(object):
         return {
             "loss_triplet": loss * self._scale,
         }
+
+
+class CircleLoss(object):
+    def __init__(self, cfg):
+        self._scale = cfg.MODEL.LOSSES.SCALE_TRI
+
+        self.m = 0.25
+        self.s = 128
+
+    def __call__(self, _, global_features, targets):
+        global_features = normalize(global_features, axis=-1)
+
+        sim_mat = torch.matmul(global_features, global_features.t())
+
+        N = sim_mat.size(0)
+        is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() - torch.eye(N).to(sim_mat.device)
+        is_pos = is_pos.bool()
+        is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
+
+        s_p = sim_mat[is_pos].contiguous().view(N, -1)
+        s_n = sim_mat[is_neg].contiguous().view(N, -1)
+
+        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
+        alpha_n = F.relu(s_n.detach() + self.m)
+        delta_p = 1 - self.m
+        delta_n = self.m
+
+        logit_p = - self.s * alpha_p * (s_p - delta_p)
+        logit_n = self.s * alpha_n * (s_n - delta_n)
+
+        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+        return {
+            "loss_circle": loss * self._scale,
+        }
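Unlike CircleHead, this pairwise variant mines positive and negative similarities from the batch itself and negates the positive logits, packing the circle loss into one expression: softplus(logsumexp(logit_p) + logsumexp(logit_n)) equals log(1 + sum_i exp(logit_p_i) * sum_j exp(logit_n_j)). A self-contained numeric check of that identity (a sketch, not fast-reid code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logit_p = torch.randn(4, 3)  # toy positive-pair logits, one row per anchor
logit_n = torch.randn(4, 5)  # toy negative-pair logits

# logsumexp(p) + logsumexp(n) = log(sum(exp(p)) * sum(exp(n))),
# and softplus(x) = log(1 + exp(x)), so the two forms coincide.
lhs = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1))
rhs = torch.log(1 + logit_p.exp().sum(dim=1) * logit_n.exp().sum(dim=1))
assert torch.allclose(lhs, rhs, atol=1e-6)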
|
Loading…
x
Reference in New Issue
Block a user