update freeze layer

update preciseBN
update circle loss with metric learning and cross entropy loss forms
update loss call methods
pull/43/head
liaoxingyu 2020-04-06 23:34:27 +08:00
parent 6a8961ce48
commit 4d2fa28dbb
9 changed files with 82 additions and 38 deletions


@@ -173,6 +173,11 @@ _C.TEST = CN()
 _C.TEST.EVAL_PERIOD = 50
 _C.TEST.IMS_PER_BATCH = 128
+# Precise BN
+_C.TEST.PRECISE_BN = CN()
+_C.TEST.PRECISE_BN.ENABLED = False
+_C.TEST.PRECISE_BN.DATASET = 'Market1501'
+_C.TEST.PRECISE_BN.NUM_ITER = 300

 # ---------------------------------------------------------------------------- #
 # Misc options
 # ---------------------------------------------------------------------------- #

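For context, turning the new option on from a training script would look something like the sketch below; the get_cfg import path is an assumption, and the keys simply mirror the _C.TEST.PRECISE_BN node added above.

from fastreid.config import get_cfg  # assumed entry point for the defaults above

cfg = get_cfg()
cfg.TEST.PRECISE_BN.ENABLED = True          # recompute BN running stats before eval
cfg.TEST.PRECISE_BN.NUM_ITER = 300          # forward passes used for the estimate
cfg.TEST.PRECISE_BN.DATASET = 'Market1501'  # dataset the extra passes draw from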

@@ -257,17 +257,21 @@ class DefaultTrainer(SimpleTrainer):
         ret = [
             hooks.IterationTimer(),
             hooks.LRScheduler(self.optimizer, self.scheduler),
-            # hooks.PreciseBN(
-            #     # Run at the same freq as (but before) evaluation.
-            #     cfg.TEST.EVAL_PERIOD,
-            #     self.model,
-            #     # Build a new data loader to not affect training
-            #     self.build_train_loader(cfg),
-            #     cfg.TEST.PRECISE_BN.NUM_ITER,
-            # )
-            # if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
-            # else None,
-            hooks.FreezeLayer(self.model, cfg.MODEL.OPEN_LAYERS, cfg.SOLVER.FREEZE_ITERS),
+            hooks.PreciseBN(
+                # Run at the same freq as (but before) evaluation.
+                cfg.TEST.EVAL_PERIOD,
+                self.model,
+                # Build a new data loader to not affect training
+                self.build_train_loader(cfg),
+                cfg.TEST.PRECISE_BN.NUM_ITER,
+            )
+            if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model)
+            else None,
+            hooks.FreezeLayer(
+                self.model,
+                cfg.MODEL.OPEN_LAYERS,
+                cfg.SOLVER.FREEZE_ITERS)
+            if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0 else None,
         ]
         # Do PreciseBN before checkpointer, because it updates the model and need to

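The PreciseBN hook follows the standard recipe: run extra forward passes over training data and replace each BN layer's running statistics with a freshly accumulated average, since EMA statistics at the end of training can lag the final weights. A minimal sketch of the idea (a simplified stand-in, not the actual hook; restoring the original momentum values afterwards is omitted):

import torch

@torch.no_grad()
def recompute_bn_stats(model, data_loader, num_iter=300):
    # Collect BN layers that track running statistics.
    bn_layers = [m for m in model.modules()
                 if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d))
                 and m.track_running_stats]
    # momentum=None switches BN to a cumulative (true) average.
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None
    model.train()  # BN only updates running stats in train mode
    for i, inputs in enumerate(data_loader):
        if i >= num_iter:
            break
        model(inputs)  # input format depends on the data loader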

@@ -430,9 +430,7 @@ class FreezeLayer(HookBase):
         self.model = model
         self.freeze_iters = freeze_iters

-        if open_layer_names == '':
-            self.freeze_iters = -1
-        elif isinstance(open_layer_names, str):
+        if isinstance(open_layer_names, str):
            open_layer_names = [open_layer_names]
         self.open_layer_names = open_layer_names

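With this simplification, an empty OPEN_LAYERS string is handled once at hook-creation time in the trainer (the FREEZE_ITERS > 0 guard above) instead of inside the hook. The freezing itself amounts to something like this sketch (a hypothetical helper, not the hook's exact code):

def open_only(model, open_layers):
    # Train only the named child modules; freeze everything else.
    for name, module in model.named_children():
        trainable = name in open_layers
        module.train(trainable)  # frozen BN layers also stop updating stats
        for p in module.parameters():
            p.requires_grad = trainable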

@@ -38,7 +38,8 @@ class BNneckHead(nn.Module):
             return bn_feat

         # training
         pred_class_logits = self.classifier(bn_feat)
-        return pred_class_logits, global_feat
+        # return pred_class_logits, global_feat
+        return pred_class_logits, bn_feat

     @classmethod
     def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:


@@ -48,20 +48,19 @@ class CircleHead(nn.Module):
         if not self.training:
             return bn_feat

-        cos_sim = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
-        alpha_p = F.relu(-cos_sim + 1 + self._m)
-        alpha_n = F.relu(cos_sim + self._m)
-        margin_p = 1 - self._m
-        margin_n = self._m
+        sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
+        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
+        alpha_n = F.relu(sim_mat.detach() + self._m)
+        delta_p = 1 - self._m
+        delta_n = self._m

-        sp = alpha_p * (cos_sim - margin_p)
-        sn = alpha_n * (cos_sim - margin_n)
+        s_p = self._s * alpha_p * (sim_mat - delta_p)
+        s_n = self._s * alpha_n * (sim_mat - delta_n)

-        one_hot = torch.zeros(cos_sim.size()).to(targets.device)
+        one_hot = torch.zeros(sim_mat.size()).to(targets.device)
         one_hot.scatter_(1, targets.view(-1, 1).long(), 1)

-        pred_class_logits = one_hot * sp + ((1.0 - one_hot) * sn)
-        pred_class_logits *= self._s
+        pred_class_logits = one_hot * s_p + (1.0 - one_hot) * s_n

         return pred_class_logits, global_feat

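Rewritten as a standalone function, the classification form above produces per-class logits for an ordinary cross entropy; the detach() calls keep the adaptive weights alpha_p/alpha_n out of the backward pass, so gradients flow only through the sim_mat factor inside the margins. A self-contained sketch with hypothetical names:

import torch
import torch.nn.functional as F

def circle_logits(feat, weight, targets, m=0.25, s=128):
    # Cosine similarities between L2-normalized features and class weights.
    sim = F.linear(F.normalize(feat), F.normalize(weight))
    # Adaptive, detached scaling factors (larger for harder examples).
    alpha_p = F.relu(-sim.detach() + 1 + m)
    alpha_n = F.relu(sim.detach() + m)
    one_hot = F.one_hot(targets, num_classes=weight.size(0)).float()
    # delta_p = 1 - m for the target class, delta_n = m for the rest.
    return one_hot * (s * alpha_p * (sim - (1 - m))) \
           + (1.0 - one_hot) * (s * alpha_n * (sim - m))

F.cross_entropy on these logits then reproduces the cross-entropy form of circle loss.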

@@ -7,9 +7,9 @@
 from torch import nn

 from .build import REID_HEADS_REGISTRY
-from ..losses import CrossEntropyLoss, TripletLoss
-from ..model_utils import weights_init_classifier, weights_init_kaiming
-from ...layers import bn_no_bias, Flatten
+from .. import losses as Loss
+from ..model_utils import weights_init_classifier
+from ...layers import Flatten


 @REID_HEADS_REGISTRY.register()
@@ -41,11 +41,8 @@ class LinearHead(nn.Module):
     @classmethod
     def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
         loss_dict = {}
-        if "CrossEntropyLoss" in cfg.MODEL.LOSSES.NAME and pred_class_logits is not None:
-            loss = CrossEntropyLoss(cfg)(pred_class_logits, gt_classes)
-            loss_dict.update(loss)
-        if "TripletLoss" in cfg.MODEL.LOSSES.NAME and global_features is not None:
-            loss = TripletLoss(cfg)(global_features, gt_classes)
+        for loss_name in cfg.MODEL.LOSSES.NAME:
+            loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
             loss_dict.update(loss)
         # rename
         name_loss_dict = {}

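The getattr dispatch is what drives the signature changes in the loss classes below: every loss is constructed from cfg and called as loss(pred_class_logits, global_features, gt_classes), with each implementation ignoring the slot it does not need. Given a batch's logits, features, and targets, usage reduces to (import path assumed):

from fastreid.modeling import losses as Loss  # assumed package path

ce = Loss.CrossEntropyLoss(cfg)
tri = Loss.TripletLoss(cfg)
loss_dict = {}
loss_dict.update(ce(pred_class_logits, None, gt_classes))   # features slot unused
loss_dict.update(tri(None, global_features, gt_classes))    # logits slot unused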

@@ -5,4 +5,4 @@
 """

 from .cross_entroy_loss import CrossEntropyLoss
-from .metric_loss import TripletLoss
+from .metric_loss import *

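The wildcard import works together with the __all__ list added to metric_loss.py below: only the names listed there are re-exported, so new losses become importable without touching this file again. For instance:

# Governed by metric_loss.__all__ after this change:
from fastreid.modeling.losses import CrossEntropyLoss, TripletLoss, CircleLoss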

@@ -40,7 +40,7 @@ class CrossEntropyLoss(object):
             storage = get_event_storage()
             storage.put_scalar("cls_accuracy", ret[0])

-    def __call__(self, pred_class_logits, gt_classes):
+    def __call__(self, pred_class_logits, _, gt_classes):
        """
        Compute the softmax cross entropy loss for box classification.
        Returns:


@@ -6,8 +6,11 @@
 import torch
 from torch import nn
 import torch.nn.functional as F

+__all__ = ["TripletLoss", "CircleLoss"]
+

 def normalize(x, axis=-1):
     """Normalizing to unit length along the specified dimension.
     Args:
@@ -101,8 +104,10 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
     assert len(dist_mat.size()) == 2
     assert dist_mat.size(0) == dist_mat.size(1)

-    dist_ap = dist_mat * is_pos.float()
-    dist_an = dist_mat * is_neg.float()
+    is_pos = is_pos.float()
+    is_neg = is_neg.float()
+    dist_ap = dist_mat * is_pos
+    dist_an = dist_mat * is_neg

     weights_ap = softmax_weights(dist_ap, is_pos)
     weights_an = softmax_weights(-dist_an, is_neg)
@@ -130,7 +135,7 @@ class TripletLoss(object):
         else:
             self.ranking_loss = nn.SoftMarginLoss()

-    def __call__(self, global_features, targets):
+    def __call__(self, _, global_features, targets):
         if self._normalize_feature:
             global_features = normalize(global_features, axis=-1)
@@ -157,3 +162,38 @@ class TripletLoss(object):
         return {
             "loss_triplet": loss * self._scale,
         }
+
+
+class CircleLoss(object):
+    def __init__(self, cfg):
+        self._scale = cfg.MODEL.LOSSES.SCALE_TRI
+
+        self.m = 0.25
+        self.s = 128
+
+    def __call__(self, _, global_features, targets):
+        global_features = normalize(global_features, axis=-1)
+
+        sim_mat = torch.matmul(global_features, global_features.t())
+
+        N = sim_mat.size(0)
+        is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() - torch.eye(N).to(sim_mat.device)
+        is_pos = is_pos.bool()
+        is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
+
+        s_p = sim_mat[is_pos].contiguous().view(N, -1)
+        s_n = sim_mat[is_neg].contiguous().view(N, -1)
+
+        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
+        alpha_n = F.relu(s_n.detach() + self.m)
+        delta_p = 1 - self.m
+        delta_n = self.m
+
+        logit_p = - self.s * alpha_p * (s_p - delta_p)
+        logit_n = self.s * alpha_n * (s_n - delta_n)
+
+        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+        return {
+            "loss_circle": loss * self._scale,
+        }
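A quick usage sketch of the pair-wise form through the shared three-argument call. Note that the view(N, -1) reshape assumes a PK-sampled batch, where every image has the same number of positives and negatives (the values below are hypothetical):

import torch

loss_fn = CircleLoss(cfg)  # scale read from cfg.MODEL.LOSSES.SCALE_TRI
feats = torch.randn(64, 2048)                    # P=16 identities x K=4 images
targets = torch.arange(16).repeat_interleave(4)  # 4 images per identity
print(loss_fn(None, feats, targets))             # {'loss_circle': tensor(...)}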