mirror of https://github.com/JDAI-CV/fast-reid.git
update focal loss

update dataset info display
update separate lr
update adaptive label smooth regularization

parent 9cf222e093
commit be9faa5605
@@ -48,6 +48,16 @@ _C.MODEL.HEADS.NAME = "BNneckHead"
 _C.MODEL.HEADS.POOL_LAYER = 'avgpool'
 _C.MODEL.HEADS.NUM_CLASSES = 751
 
+# Arcface head
+_C.MODEL.HEADS.ARCFACE = CN()
+_C.MODEL.HEADS.ARCFACE.MARGIN = 0.5
+_C.MODEL.HEADS.ARCFACE.SCALE = 30.0
+
+# Circle Loss
+_C.MODEL.HEADS.CIRCLE = CN()
+_C.MODEL.HEADS.CIRCLE.MARGIN = 0.15
+_C.MODEL.HEADS.CIRCLE.SCALE = 128.0
+
 # ---------------------------------------------------------------------------- #
 # REID LOSSES options
 # ---------------------------------------------------------------------------- #
@@ -55,16 +65,26 @@ _C.MODEL.LOSSES = CN()
 _C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
 
 # Cross Entropy Loss options
-_C.MODEL.LOSSES.SMOOTH_ON = False
-_C.MODEL.LOSSES.EPSILON = 0.1
-_C.MODEL.LOSSES.SCALE_CE = 1.0
+_C.MODEL.LOSSES.CE = CN()
+# if epsilon == 0, it means no label smooth regularization,
+# if epsilon == -1, it means adaptive label smooth regularization
+_C.MODEL.LOSSES.CE.EPSILON = 0.0
+_C.MODEL.LOSSES.CE.ALPHA = 0.2
+_C.MODEL.LOSSES.CE.SCALE = 1.0
 
 # Triplet Loss options
-_C.MODEL.LOSSES.MARGIN = 0.3
-_C.MODEL.LOSSES.NORM_FEAT = False
-_C.MODEL.LOSSES.HARD_MINING = True
-_C.MODEL.LOSSES.USE_COSINE_DIST = True
-_C.MODEL.LOSSES.SCALE_TRI = 1.0
+_C.MODEL.LOSSES.TRI = CN()
+_C.MODEL.LOSSES.TRI.MARGIN = 0.3
+_C.MODEL.LOSSES.TRI.NORM_FEAT = False
+_C.MODEL.LOSSES.TRI.HARD_MINING = True
+_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True
+_C.MODEL.LOSSES.TRI.SCALE = 1.0
+
+# Focal Loss options
+_C.MODEL.LOSSES.FL = CN()
+_C.MODEL.LOSSES.FL.ALPHA = 0.25
+_C.MODEL.LOSSES.FL.GAMMA = 2
+_C.MODEL.LOSSES.FL.SCALE = 1.0
 
 # Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
 # to be loaded to the model. You can find available models in the model zoo.
@@ -136,7 +156,8 @@ _C.SOLVER.OPT = "Adam"
 _C.SOLVER.MAX_ITER = 40000
 
 _C.SOLVER.BASE_LR = 3e-4
-_C.SOLVER.BIAS_LR_FACTOR = 1
+_C.SOLVER.BIAS_LR_FACTOR = 1.
+_C.SOLVER.HEADS_LR_FACTOR = 1.
 
 _C.SOLVER.MOMENTUM = 0.9
@@ -144,11 +165,11 @@ _C.SOLVER.WEIGHT_DECAY = 0.0005
 _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
 
 _C.SOLVER.SCHED = "warmup"
-# warmup config
+# Warmup config
 _C.SOLVER.GAMMA = 0.1
 _C.SOLVER.STEPS = (30, 55)
 
-# cosine annealing
+# Cosine annealing
 _C.SOLVER.DELAY_ITERS = 100
 _C.SOLVER.COS_ANNEAL_ITERS = 100
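Taken together, the config hunks above replace the flat _C.MODEL.LOSSES.* keys with per-loss sub-nodes (CE, TRI, FL) and add SOLVER.HEADS_LR_FACTOR. A minimal sketch of overriding the new keys with yacs directly (the node layout is copied from the diff; the override values are only illustrative):

    from yacs.config import CfgNode as CN

    _C = CN()
    _C.MODEL = CN()
    _C.MODEL.LOSSES = CN()
    _C.MODEL.LOSSES.CE = CN()
    _C.MODEL.LOSSES.CE.EPSILON = 0.0  # 0: no smoothing, >0: fixed, -1: adaptive
    _C.MODEL.LOSSES.CE.ALPHA = 0.2    # only read by the adaptive branch
    _C.SOLVER = CN()
    _C.SOLVER.HEADS_LR_FACTOR = 1.

    cfg = _C.clone()
    # e.g. switch on adaptive label smoothing and train the heads 10x faster:
    cfg.merge_from_list(["MODEL.LOSSES.CE.EPSILON", -1.0,
                         "SOLVER.HEADS_LR_FACTOR", 10.0])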
@@ -23,6 +23,7 @@ def build_reid_train_loader(cfg):
     for d in cfg.DATASETS.NAMES:
         logger.info('prepare training set {}'.format(d))
         dataset = DATASET_REGISTRY.get(d)()
+        dataset.show_summary()
         train_items.extend(dataset.train)
 
     train_set = CommDataset(train_items, train_transforms, relabel=True)
@@ -52,6 +53,7 @@ def build_reid_test_loader(cfg, dataset_name):
     logger = logging.getLogger(__name__)
     logger.info('prepare test set {}'.format(dataset_name))
     dataset = DATASET_REGISTRY.get(dataset_name)()
+    dataset.show_summary()
     test_items = dataset.query + dataset.gallery
 
     test_set = CommDataset(test_items, test_transforms, relabel=False)
@@ -9,6 +9,7 @@ import os
 
 import numpy as np
 import torch
+import logging
 
 
 class Dataset(object):
@@ -197,19 +198,19 @@ class ImageDataset(Dataset):
         super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
 
     def show_summary(self):
+        logger = logging.getLogger(__name__)
         num_train_pids, num_train_cams = self.parse_data(self.train)
         num_query_pids, num_query_cams = self.parse_data(self.query)
         num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
 
-        print('=> Loaded {}'.format(self.__class__.__name__))
-        print(' ----------------------------------------')
-        print(' subset   | # ids | # images | # cameras')
-        print(' ----------------------------------------')
-        print(' train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
-        print(' query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
-        print(' gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
-        print(' ----------------------------------------')
+        logger.info('=> Loaded {}'.format(self.__class__.__name__))
+        logger.info(' ----------------------------------------')
+        logger.info(' subset   | # ids | # images | # cameras')
+        logger.info(' ----------------------------------------')
+        logger.info(' train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
+        logger.info(' query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
+        logger.info(' gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
+        logger.info(' ----------------------------------------')
 
 # class VideoDataset(Dataset):
 #     """A base class representing VideoDataset.
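show_summary() now goes through the module logger instead of print, so the dataset table only appears once logging is configured. A stdlib-only sketch (assuming no other handler setup) to make it visible:

    import logging

    # route INFO records from the dataset modules to stderr
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(name)s: %(message)s")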
@@ -13,6 +13,7 @@ from torch.nn import Parameter
 
 from .build import REID_HEADS_REGISTRY
 from .linear_head import LinearHead
+from ..losses.loss_utils import one_hot
 from ..model_utils import weights_init_kaiming
 from ...layers import NoBiasBatchNorm1d, Flatten
@@ -32,9 +33,8 @@ class ArcfaceHead(nn.Module):
         self.bnneck.apply(weights_init_kaiming)
 
         # classifier
-        # self.adaptive_s = False
-        self._s = 30.0
-        self._m = 0.50
+        self._s = cfg.MODEL.HEADS.ARCFACE.SCALE
+        self._m = cfg.MODEL.HEADS.ARCFACE.MARGIN
 
         self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
         self.reset_parameters()
@@ -52,24 +52,13 @@ class ArcfaceHead(nn.Module):
         cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
 
         # add margin
-        theta = torch.acos(torch.clamp(cosine, -1.0+1e-7, 1.0-1e-7))
+        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
 
         phi = torch.cos(theta + self._m)
 
-        # --------------------------- convert label to one-hot ---------------------------
-        one_hot = torch.zeros_like(cosine)
-        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
-        # if self.adaptive_s:
-        #     with torch.no_grad():
-        #         B_avg = torch.where(one_hot < 1, torch.exp(self._s * cosine), torch.zeros_like(cosine))
-        #         B_avg = torch.sum(B_avg) / cosine.size(0)
-        #         theta_med = torch.median(theta[one_hot == 1])
-        #         s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med)))
-        # else:
-        #     s = self._s
-        # you can use torch.where if your torch.__version__ is 0.4
-        pred_class_logits = one_hot * phi + (1.0 - one_hot) * cosine
+        targets = one_hot(targets, self._num_classes)
+        pred_class_logits = targets * phi + (1.0 - targets) * cosine
 
         # logits re-scale
         pred_class_logits *= self._s
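The rewritten forward pass is the standard ArcFace formulation: the target-class logit becomes s*cos(theta + m) while every other class keeps s*cos(theta). A self-contained sketch of just the margin arithmetic (the function name and shapes are illustrative, and F.one_hot stands in for the repo's one_hot helper):

    import torch
    import torch.nn.functional as F

    def arcface_logits(feat, weight, targets, s=30.0, m=0.50):
        # cosine between L2-normalized features and class weights
        cosine = F.linear(F.normalize(feat), F.normalize(weight))
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        phi = torch.cos(theta + m)  # margin-penalized logit for the true class
        onehot = F.one_hot(targets, weight.size(0)).float()
        return s * (onehot * phi + (1.0 - onehot) * cosine)

    logits = arcface_logits(torch.randn(4, 512), torch.randn(751, 512),
                            torch.tensor([0, 1, 2, 3]))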
@@ -33,8 +33,8 @@ class CircleHead(nn.Module):
         self.bnneck.apply(weights_init_kaiming)
 
         # classifier
-        self._s = 128.0
-        self._m = 0.15
+        self._s = cfg.MODEL.HEADS.CIRCLE.SCALE
+        self._m = cfg.MODEL.HEADS.CIRCLE.MARGIN
 
         self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
         self.reset_parameters()
@@ -7,6 +7,7 @@ import torch
 import torch.nn.functional as F
 
 from ...utils.events import get_event_storage
+from .loss_utils import one_hot
 
 
 class CrossEntropyLoss(object):
@@ -16,9 +17,9 @@ class CrossEntropyLoss(object):
 
     def __init__(self, cfg):
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
-       self._epsilon = cfg.MODEL.LOSSES.EPSILON
-       self._smooth_on = cfg.MODEL.LOSSES.SMOOTH_ON
-       self._scale = cfg.MODEL.LOSSES.SCALE_CE
+       self._eps = cfg.MODEL.LOSSES.CE.EPSILON
+       self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
+       self._scale = cfg.MODEL.LOSSES.CE.SCALE
 
        self._topk = (1,)
@@ -47,14 +48,20 @@ class CrossEntropyLoss(object):
             scalar Tensor
         """
         self._log_accuracy(pred_class_logits, gt_classes)
-        if self._smooth_on:
-            log_probs = F.log_softmax(pred_class_logits, dim=1)
-            targets = torch.zeros(log_probs.size()).scatter_(1, gt_classes.unsqueeze(1).data.cpu(), 1)
-            targets = targets.to(pred_class_logits.device)
-            targets = (1 - self._epsilon) * targets + self._epsilon / self._num_classes
-            loss = (-targets * log_probs).mean(0).sum()
+        if self._eps >= 0:
+            smooth_param = self._eps
         else:
-            loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
+            # adaptive lsr
+            soft_label = F.softmax(pred_class_logits, dim=1)
+            smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
+
+        log_probs = F.log_softmax(pred_class_logits, dim=1)
+        with torch.no_grad():
+            targets = torch.ones_like(log_probs)
+            targets *= smooth_param / (self._num_classes - 1)
+            targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
+
+        loss = (-targets * log_probs).mean(0).sum()
         return {
             "loss_cls": loss * self._scale,
         }
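The rewrite folds the fixed and adaptive cases into one target-construction path: with CE.EPSILON >= 0 a constant mass is spread over the non-target classes (0 recovers plain cross entropy), and with CE.EPSILON == -1 the mass is ALPHA times the model's own softmax confidence on the true class, so confidently predicted samples are smoothed hardest. A standalone sketch mirroring the new branch (names are illustrative):

    import torch
    import torch.nn.functional as F

    def smoothed_ce(logits, gt, eps=-1.0, alpha=0.2):
        num_classes = logits.size(1)
        if eps >= 0:
            smooth = eps  # fixed smoothing; 0 gives plain one-hot targets
        else:
            # adaptive: per-sample mass from confidence on the true class
            p = F.softmax(logits, dim=1)
            smooth = alpha * p[torch.arange(logits.size(0)), gt].unsqueeze(1)
        log_probs = F.log_softmax(logits, dim=1)
        with torch.no_grad():
            targets = torch.ones_like(log_probs) * smooth / (num_classes - 1)
            targets.scatter_(1, gt.unsqueeze(1), 1 - smooth)
        return (-targets * log_probs).mean(0).sum()

    loss = smoothed_ce(torch.randn(8, 751), torch.randint(0, 751, (8,)))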
@@ -0,0 +1,114 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+from .loss_utils import one_hot
+
+
+# based on:
+# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
+
+def focal_loss(
+        input: torch.Tensor,
+        target: torch.Tensor,
+        alpha: float,
+        gamma: float = 2.0,
+        reduction: str = 'mean', ) -> torch.Tensor:
+    r"""Function that computes Focal loss.
+    See :class:`fastreid.modeling.losses.FocalLoss` for details.
+    """
+    if not torch.is_tensor(input):
+        raise TypeError("Input type is not a torch.Tensor. Got {}"
+                        .format(type(input)))
+
+    if not len(input.shape) >= 2:
+        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
+                         .format(input.shape))
+
+    if input.size(0) != target.size(0):
+        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
+                         .format(input.size(0), target.size(0)))
+
+    n = input.size(0)
+    out_size = (n,) + input.size()[2:]
+    if target.size()[1:] != input.size()[2:]:
+        raise ValueError('Expected target size {}, got {}'.format(
+            out_size, target.size()))
+
+    if not input.device == target.device:
+        raise ValueError(
+            "input and target must be in the same device. Got: {}".format(
+                input.device, target.device))
+
+    # compute softmax over the classes axis
+    input_soft = F.softmax(input, dim=1)
+
+    # create the labels one hot tensor
+    target_one_hot = one_hot(
+        target, num_classes=input.shape[1],
+        dtype=input.dtype)
+
+    # compute the actual focal loss
+    weight = torch.pow(-input_soft + 1., gamma)
+
+    focal = -alpha * weight * torch.log(input_soft)
+    loss_tmp = torch.sum(target_one_hot * focal, dim=1)
+
+    if reduction == 'none':
+        loss = loss_tmp
+    elif reduction == 'mean':
+        loss = torch.mean(loss_tmp)
+    elif reduction == 'sum':
+        loss = torch.sum(loss_tmp)
+    else:
+        raise NotImplementedError("Invalid reduction mode: {}"
+                                  .format(reduction))
+    return loss
+
+
+class FocalLoss(object):
+    r"""Criterion that computes Focal loss.
+    According to [1], the Focal loss is computed as follows:
+    .. math::
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+    where:
+       - :math:`p_t` is the model's estimated probability for each class.
+    Arguments:
+        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
+        gamma (float): Focusing parameter :math:`\gamma >= 0`.
+        reduction (str, optional): Specifies the reduction to apply to the
+            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
+            'mean': the sum of the output will be divided by the number of elements
+            in the output, 'sum': the output will be summed. Default: 'none'.
+    Shape:
+        - Input: :math:`(N, C, *)` where C = number of classes.
+        - Target: :math:`(N, *)` where each value is
+          :math:`0 ≤ targets[i] ≤ C−1`.
+    Examples:
+        >>> N = 5  # num_classes
+        >>> loss = FocalLoss(cfg)
+        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+        >>> output = loss(input, target)
+        >>> output.backward()
+    References:
+        [1] https://arxiv.org/abs/1708.02002
+    """
+
+    # def __init__(self, alpha: float, gamma: float = 2.0,
+    #              reduction: str = 'none') -> None:
+    def __init__(self, cfg):
+        self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA
+        self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA
+        self._scale: float = cfg.MODEL.LOSSES.FL.SCALE
+
+    def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict:
+        loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma)
+        return {
+            'loss_focal': loss * self._scale,
+        }
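The functional form above also accepts the (N, C) logits and (N,) integer targets used by the reid heads (the shape checks pass because both trailing sizes are empty). A quick usage sketch, assuming focal_loss from the new file is in scope:

    import torch

    # focal_loss as defined in the file above
    logits = torch.randn(8, 751, requires_grad=True)  # (N, C) head output
    labels = torch.randint(0, 751, (8,))              # integer identity labels

    loss = focal_loss(logits, labels, alpha=0.25, gamma=2.0, reduction='mean')
    loss.backward()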
@@ -0,0 +1,57 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+from typing import Optional
+
+import torch
+
+
+# based on:
+# https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
+
+def one_hot(labels: torch.Tensor,
+            num_classes: int,
+            dtype: Optional[torch.dtype] = None, ) -> torch.Tensor:
+    # eps: Optional[float] = 1e-6) -> torch.Tensor:
+    r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
+    Args:
+        labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
+            where N is batch size. Each value is an integer
+            representing correct classification.
+        num_classes (int): number of classes in labels.
+        device (Optional[torch.device]): the desired device of returned tensor.
+            Default: if None, uses the current device for the default tensor type
+            (see torch.set_default_tensor_type()). device will be the CPU for CPU
+            tensor types and the current CUDA device for CUDA tensor types.
+        dtype (Optional[torch.dtype]): the desired data type of returned
+            tensor. Default: if None, infers data type from values.
+    Returns:
+        torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
+    Examples::
+        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
+        >>> one_hot(labels, num_classes=3)
+        tensor([[[[1., 0.],
+                  [0., 1.]],
+                 [[0., 1.],
+                  [0., 0.]],
+                 [[0., 0.],
+                  [1., 0.]]]]
+    """
+    if not torch.is_tensor(labels):
+        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
+                        .format(type(labels)))
+    if not labels.dtype == torch.int64:
+        raise ValueError(
+            "labels must be of the same dtype torch.int64. Got: {}".format(
+                labels.dtype))
+    if num_classes < 1:
+        raise ValueError("The number of classes must be bigger than one."
+                         " Got: {}".format(num_classes))
+    device = labels.device
+    shape = labels.shape
+    one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
+                          device=device, dtype=dtype)
+    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
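For the flat (N,) identity labels the heads pass in, the helper reduces to an ordinary class-indicator matrix:

    import torch

    # one_hot as defined in the file above
    labels = torch.tensor([0, 2, 1])         # (N,) labels from a reid batch
    onehot = one_hot(labels, num_classes=3)  # -> shape (3, 3)
    # tensor([[1., 0., 0.],
    #         [0., 0., 1.],
    #         [0., 1., 0.]])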
@@ -15,12 +15,13 @@ def build_optimizer(cfg, model):
             continue
         lr = cfg.SOLVER.BASE_LR
         weight_decay = cfg.SOLVER.WEIGHT_DECAY
-        # if "heads" in key:
-        #     lr = cfg.SOLVER.BASE_LR * 10
+        if "heads" in key:
+            lr *= cfg.SOLVER.HEADS_LR_FACTOR
         if "bias" in key:
-            lr = lr * cfg.SOLVER.BIAS_LR_FACTOR
+            lr *= cfg.SOLVER.BIAS_LR_FACTOR
             weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
 
     solver_opt = cfg.SOLVER.OPT
     if hasattr(optim, solver_opt):
         if solver_opt == "SGD":
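With the commented-out hard-coded 10x head multiplier replaced by cfg-driven factors, the per-parameter learning rate is now composed multiplicatively. A standalone sketch of the grouping logic (the model and factor values are illustrative; parameters are matched by the "heads"/"bias" substrings in their names):

    import torch
    import torch.nn as nn

    BASE_LR, HEADS_LR_FACTOR, BIAS_LR_FACTOR = 3e-4, 10.0, 2.0

    model = nn.Sequential()
    model.add_module("backbone", nn.Linear(16, 16))
    model.add_module("heads", nn.Linear(16, 751))

    params = []
    for key, value in model.named_parameters():
        lr = BASE_LR
        if "heads" in key:
            lr *= HEADS_LR_FACTOR  # heads.weight gets 3e-3 ...
        if "bias" in key:
            lr *= BIAS_LR_FACTOR   # ... and heads.bias 6e-3
        params.append({"params": [value], "lr": lr})

    opt = torch.optim.Adam(params, lr=BASE_LR)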