mirror of https://github.com/JDAI-CV/fast-reid.git

update freeze layer

update preciseBN; update circle loss with metric-learning and cross-entropy loss forms; update loss call methods

branch: pull/43/head
parent 6a8961ce48
commit 4d2fa28dbb
@@ -173,6 +173,11 @@ _C.TEST = CN()
 _C.TEST.EVAL_PERIOD = 50
 _C.TEST.IMS_PER_BATCH = 128

+# Precise BN
+_C.TEST.PRECISE_BN = CN()
+_C.TEST.PRECISE_BN.ENABLED = False
+_C.TEST.PRECISE_BN.DATASET = 'Market1501'
+_C.TEST.PRECISE_BN.NUM_ITER = 300
 # ---------------------------------------------------------------------------- #
 # Misc options
 # ---------------------------------------------------------------------------- #
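Note: these options follow the yacs CfgNode pattern this config file already uses. A minimal sketch of overriding them at runtime (merge_from_list is standard yacs API; the option names are exactly the ones added above):

    from yacs.config import CfgNode as CN

    _C = CN()
    _C.TEST = CN()
    _C.TEST.PRECISE_BN = CN()
    _C.TEST.PRECISE_BN.ENABLED = False
    _C.TEST.PRECISE_BN.DATASET = 'Market1501'
    _C.TEST.PRECISE_BN.NUM_ITER = 300

    cfg = _C.clone()
    # standard yacs override, e.g. from command-line opts
    cfg.merge_from_list(["TEST.PRECISE_BN.ENABLED", True,
                         "TEST.PRECISE_BN.NUM_ITER", 200])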
@@ -257,17 +257,21 @@ class DefaultTrainer(SimpleTrainer):
         ret = [
             hooks.IterationTimer(),
             hooks.LRScheduler(self.optimizer, self.scheduler),
-            # hooks.PreciseBN(
-            #     # Run at the same freq as (but before) evaluation.
-            #     cfg.TEST.EVAL_PERIOD,
-            #     self.model,
-            #     # Build a new data loader to not affect training
-            #     self.build_train_loader(cfg),
-            #     cfg.TEST.PRECISE_BN.NUM_ITER,
-            # )
-            # if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
-            # else None,
-            hooks.FreezeLayer(self.model, cfg.MODEL.OPEN_LAYERS, cfg.SOLVER.FREEZE_ITERS)
+            hooks.PreciseBN(
+                # Run at the same freq as (but before) evaluation.
+                cfg.TEST.EVAL_PERIOD,
+                self.model,
+                # Build a new data loader to not affect training
+                self.build_train_loader(cfg),
+                cfg.TEST.PRECISE_BN.NUM_ITER,
+            )
+            if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model)
+            else None,
+            hooks.FreezeLayer(
+                self.model,
+                cfg.MODEL.OPEN_LAYERS,
+                cfg.SOLVER.FREEZE_ITERS)
+            if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0 else None,
         ]

         # Do PreciseBN before checkpointer, because it updates the model and need to
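Note: PreciseBN re-estimates BatchNorm running statistics right before evaluation, since EMA statistics lag behind the current weights. A sketch of that computation, modeled on fvcore's update_bn_stats rather than the exact hook wired in above (assumes the model is in train mode and the loader yields model-ready inputs):

    import itertools

    import torch
    from torch import nn

    @torch.no_grad()
    def recompute_bn_stats(model, data_loader, num_iter):
        bn_layers = [m for m in model.modules()
                     if isinstance(m, nn.modules.batchnorm._BatchNorm)]
        momenta = [bn.momentum for bn in bn_layers]
        for bn in bn_layers:
            bn.momentum = 1.0  # each forward now overwrites running stats
        mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
        var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
        for i, inputs in enumerate(itertools.islice(data_loader, num_iter)):
            model(inputs)
            for j, bn in enumerate(bn_layers):
                # cumulative average of the per-batch statistics
                mean[j] += (bn.running_mean - mean[j]) / (i + 1)
                var[j] += (bn.running_var - var[j]) / (i + 1)
        for j, bn in enumerate(bn_layers):
            bn.running_mean.copy_(mean[j])
            bn.running_var.copy_(var[j])
            bn.momentum = momenta[j]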
@@ -430,9 +430,7 @@ class FreezeLayer(HookBase):
         self.model = model

         self.freeze_iters = freeze_iters
-        if open_layer_names == '':
-            self.freeze_iters = -1
-        elif isinstance(open_layer_names, str):
+        if isinstance(open_layer_names, str):
             open_layer_names = [open_layer_names]

         self.open_layer_names = open_layer_names
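Note: with the simplified __init__, an empty OPEN_LAYERS no longer flips freeze_iters to -1 inside the hook; that case is now handled at construction time by the "if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0 else None" guard in the trainer hunk above. For intuition, the freezing itself amounts to something like this illustrative helper (not the fastreid hook's exact code):

    from torch import nn

    def set_open_layers(model: nn.Module, open_layer_names):
        # keep only the named top-level submodules trainable; freeze the rest
        for name, module in model.named_children():
            is_open = name in open_layer_names
            module.train(mode=is_open)
            for p in module.parameters():
                p.requires_grad = is_open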
@@ -38,7 +38,8 @@ class BNneckHead(nn.Module):
             return bn_feat
         # training
         pred_class_logits = self.classifier(bn_feat)
-        return pred_class_logits, global_feat
+        # return pred_class_logits, global_feat
+        return pred_class_logits, bn_feat

     @classmethod
     def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
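Note: the behavioral change here is that metric losses now receive the post-BN embedding (bn_feat) instead of the pre-BN one. A minimal sketch of the BNNeck pattern this head implements (dimensions and init are illustrative; the bias-freezing mirrors the bn_no_bias helper referenced elsewhere in this commit):

    import torch
    from torch import nn

    class BNneckSketch(nn.Module):
        def __init__(self, in_feat=2048, num_classes=751):
            super().__init__()
            self.bnneck = nn.BatchNorm1d(in_feat)
            self.bnneck.bias.requires_grad_(False)  # bias-free BN
            self.classifier = nn.Linear(in_feat, num_classes, bias=False)

        def forward(self, global_feat):
            bn_feat = self.bnneck(global_feat)
            if not self.training:
                return bn_feat  # inference embedding
            pred_class_logits = self.classifier(bn_feat)
            # as in the hunk above: return bn_feat, not global_feat,
            # so classification and metric losses share one embedding
            return pred_class_logits, bn_feat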
@@ -48,20 +48,19 @@ class CircleHead(nn.Module):
         if not self.training:
             return bn_feat

-        cos_sim = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
-        alpha_p = F.relu(-cos_sim + 1 + self._m)
-        alpha_n = F.relu(cos_sim + self._m)
-        margin_p = 1 - self._m
-        margin_n = self._m
+        sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
+        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
+        alpha_n = F.relu(sim_mat.detach() + self._m)
+        delta_p = 1 - self._m
+        delta_n = self._m

-        sp = alpha_p * (cos_sim - margin_p)
-        sn = alpha_n * (cos_sim - margin_n)
+        s_p = self._s * alpha_p * (sim_mat - delta_p)
+        s_n = self._s * alpha_n * (sim_mat - delta_n)

-        one_hot = torch.zeros(cos_sim.size()).to(targets.device)
+        one_hot = torch.zeros(sim_mat.size()).to(targets.device)
         one_hot.scatter_(1, targets.view(-1, 1).long(), 1)

-        pred_class_logits = one_hot * sp + ((1.0 - one_hot) * sn)
-        pred_class_logits *= self._s
+        pred_class_logits = one_hot * s_p + (1.0 - one_hot) * s_n

         return pred_class_logits, global_feat
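Note: the renames match the notation of Circle Loss (Sun et al., CVPR 2020). With cosine similarities s_j between the embedding and the class weights, margin m (self._m) and scale gamma (self._s), the logits fed to softmax cross entropy are

    logit_j = gamma * alpha_p * (s_j - delta_p)   if j == y (target class)
    logit_j = gamma * alpha_n * (s_j - delta_n)   otherwise

    alpha_p = relu(1 + m - s_j),   alpha_n = relu(s_j + m)
    delta_p = 1 - m,               delta_n = m

The added .detach() calls treat the adaptive weights alpha_p/alpha_n as constants during backprop, and folding self._s into s_p/s_n replaces the old trailing "pred_class_logits *= self._s".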
@@ -7,9 +7,9 @@
 from torch import nn

 from .build import REID_HEADS_REGISTRY
-from ..losses import CrossEntropyLoss, TripletLoss
-from ..model_utils import weights_init_classifier, weights_init_kaiming
-from ...layers import bn_no_bias, Flatten
+from .. import losses as Loss
+from ..model_utils import weights_init_classifier
+from ...layers import Flatten


 @REID_HEADS_REGISTRY.register()
@@ -41,11 +41,8 @@ class LinearHead(nn.Module):
     @classmethod
     def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
         loss_dict = {}
-        if "CrossEntropyLoss" in cfg.MODEL.LOSSES.NAME and pred_class_logits is not None:
-            loss = CrossEntropyLoss(cfg)(pred_class_logits, gt_classes)
-            loss_dict.update(loss)
-        if "TripletLoss" in cfg.MODEL.LOSSES.NAME and global_features is not None:
-            loss = TripletLoss(cfg)(global_features, gt_classes)
+        for loss_name in cfg.MODEL.LOSSES.NAME:
+            loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
             loss_dict.update(loss)
         # rename
         name_loss_dict = {}
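Note: this dispatch relies on every loss class exposing the same two-step interface, construct from cfg then call with (pred_class_logits, global_features, gt_classes), which is exactly why the __call__ signatures change in the hunks below. A toy reproduction of the pattern with stand-in loss classes (not fastreid's real bodies):

    import torch
    import torch.nn.functional as F

    class CrossEntropyLoss:
        def __init__(self, cfg): pass
        def __call__(self, pred_class_logits, _, gt_classes):
            return {"loss_cls": F.cross_entropy(pred_class_logits, gt_classes)}

    class TripletLoss:
        def __init__(self, cfg): pass
        def __call__(self, _, global_features, gt_classes):
            return {"loss_triplet": global_features.new_zeros(())}  # stub

    Loss = {"CrossEntropyLoss": CrossEntropyLoss, "TripletLoss": TripletLoss}

    def losses(loss_names, pred_class_logits, global_features, gt_classes):
        loss_dict = {}
        for loss_name in loss_names:
            loss_dict.update(Loss[loss_name](None)(
                pred_class_logits, global_features, gt_classes))
        return loss_dict

The dict lookup stands in for getattr on the losses module; each callable absorbs the argument it does not need via the `_` placeholder parameter.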
@@ -5,4 +5,4 @@
 """

 from .cross_entroy_loss import CrossEntropyLoss
-from .metric_loss import TripletLoss
+from .metric_loss import *
@@ -40,7 +40,7 @@ class CrossEntropyLoss(object):
             storage = get_event_storage()
             storage.put_scalar("cls_accuracy", ret[0])

-    def __call__(self, pred_class_logits, gt_classes):
+    def __call__(self, pred_class_logits, _, gt_classes):
         """
         Compute the softmax cross entropy loss for box classification.
         Returns:
@@ -6,8 +6,11 @@

 import torch
 from torch import nn
+import torch.nn.functional as F

+__all__ = ["TripletLoss", "CircleLoss"]
+

 def normalize(x, axis=-1):
     """Normalizing to unit length along the specified dimension.
     Args:
@@ -101,8 +104,10 @@ def weighted_example_mining(dist_mat, is_pos, is_neg)
     assert len(dist_mat.size()) == 2
     assert dist_mat.size(0) == dist_mat.size(1)

-    dist_ap = dist_mat * is_pos.float()
-    dist_an = dist_mat * is_neg.float()
+    is_pos = is_pos.float()
+    is_neg = is_neg.float()
+    dist_ap = dist_mat * is_pos
+    dist_an = dist_mat * is_neg

     weights_ap = softmax_weights(dist_ap, is_pos)
     weights_an = softmax_weights(-dist_an, is_neg)
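Note: hoisting the .float() conversions lets the same float masks feed both the distance products and softmax_weights. A self-contained sketch of the whole weighted-mining step, with an assumed softmax_weights (a masked row-wise softmax, which is what the call sites imply; not necessarily the repo's exact helper):

    import torch

    def softmax_weights(dist, mask):
        # row-wise softmax restricted to the masked entries
        max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
        w = torch.exp(dist - max_v) * mask
        return w / (w.sum(dim=1, keepdim=True) + 1e-6)

    feats = torch.randn(4, 16)
    targets = torch.tensor([0, 0, 1, 1])
    dist_mat = torch.cdist(feats, feats)
    is_pos = targets[:, None].eq(targets[None, :]).float()
    is_neg = targets[:, None].ne(targets[None, :]).float()

    dist_ap = dist_mat * is_pos
    dist_an = dist_mat * is_neg
    # soft analogues of hardest-positive / hardest-negative distances
    dist_ap = (softmax_weights(dist_ap, is_pos) * dist_ap).sum(dim=1)
    dist_an = (softmax_weights(-dist_an, is_neg) * dist_an).sum(dim=1)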
@@ -130,7 +135,7 @@ class TripletLoss(object):
         else:
             self.ranking_loss = nn.SoftMarginLoss()

-    def __call__(self, global_features, targets):
+    def __call__(self, _, global_features, targets):
         if self._normalize_feature:
             global_features = normalize(global_features, axis=-1)

@@ -157,3 +162,38 @@ class TripletLoss(object):
         return {
             "loss_triplet": loss * self._scale,
         }
+
+
+class CircleLoss(object):
+    def __init__(self, cfg):
+        self._scale = cfg.MODEL.LOSSES.SCALE_TRI
+
+        self.m = 0.25
+        self.s = 128
+
+    def __call__(self, _, global_features, targets):
+        global_features = normalize(global_features, axis=-1)
+
+        sim_mat = torch.matmul(global_features, global_features.t())
+
+        N = sim_mat.size(0)
+        is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() - torch.eye(N).to(sim_mat.device)
+        is_pos = is_pos.bool()
+        is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
+
+        s_p = sim_mat[is_pos].contiguous().view(N, -1)
+        s_n = sim_mat[is_neg].contiguous().view(N, -1)
+
+        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
+        alpha_n = F.relu(s_n.detach() + self.m)
+        delta_p = 1 - self.m
+        delta_n = self.m
+
+        logit_p = - self.s * alpha_p * (s_p - delta_p)
+        logit_n = self.s * alpha_n * (s_n - delta_n)
+
+        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+        return {
+            "loss_circle": loss * self._scale,
+        }
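Note: this pairwise CircleLoss reuses SCALE_TRI as its weight, matching the commit message's "circle loss with metric-learning ... form". One practical caveat: the .contiguous().view(N, -1) reshapes only succeed when every row has the same number of positives and negatives, i.e. under a PK sampler (P identities with exactly K images each). A usage sketch, with SimpleNamespace standing in for the real config node:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(MODEL=SimpleNamespace(
        LOSSES=SimpleNamespace(SCALE_TRI=1.0)))

    feats = torch.randn(8, 128)                      # P=2 ids x K=4 images
    targets = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    loss_dict = CircleLoss(cfg)(None, feats, targets)
    print(loss_dict["loss_circle"])                  # scalar tensor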