update focal loss

update dataset info display
update separate lr
update adaptive label smooth regularization
pull/43/head
liaoxingyu 2020-04-17 13:46:10 +08:00
parent 9cf222e093
commit be9faa5605
9 changed files with 244 additions and 52 deletions

View File

@@ -48,6 +48,16 @@ _C.MODEL.HEADS.NAME = "BNneckHead"
_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
_C.MODEL.HEADS.NUM_CLASSES = 751
# Arcface head
_C.MODEL.HEADS.ARCFACE = CN()
_C.MODEL.HEADS.ARCFACE.MARGIN = 0.5
_C.MODEL.HEADS.ARCFACE.SCALE = 30.0
# Circle Loss
_C.MODEL.HEADS.CIRCLE = CN()
_C.MODEL.HEADS.CIRCLE.MARGIN = 0.15
_C.MODEL.HEADS.CIRCLE.SCALE = 128.0
# ---------------------------------------------------------------------------- #
# REID LOSSES options
# ---------------------------------------------------------------------------- #
@@ -55,16 +65,26 @@ _C.MODEL.LOSSES = CN()
_C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
# Cross Entropy Loss options
_C.MODEL.LOSSES.SMOOTH_ON = False
_C.MODEL.LOSSES.EPSILON = 0.1
_C.MODEL.LOSSES.SCALE_CE = 1.0
_C.MODEL.LOSSES.CE = CN()
# if epsilon == 0, no label smoothing regularization is applied
# if epsilon == -1, adaptive label smoothing regularization is used
_C.MODEL.LOSSES.CE.EPSILON = 0.0
_C.MODEL.LOSSES.CE.ALPHA = 0.2
_C.MODEL.LOSSES.CE.SCALE = 1.0
# Triplet Loss options
_C.MODEL.LOSSES.MARGIN = 0.3
_C.MODEL.LOSSES.NORM_FEAT = False
_C.MODEL.LOSSES.HARD_MINING = True
_C.MODEL.LOSSES.USE_COSINE_DIST = True
_C.MODEL.LOSSES.SCALE_TRI = 1.0
_C.MODEL.LOSSES.TRI = CN()
_C.MODEL.LOSSES.TRI.MARGIN = 0.3
_C.MODEL.LOSSES.TRI.NORM_FEAT = False
_C.MODEL.LOSSES.TRI.HARD_MINING = True
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True
_C.MODEL.LOSSES.TRI.SCALE = 1.0
# Focal Loss options
_C.MODEL.LOSSES.FL = CN()
_C.MODEL.LOSSES.FL.ALPHA = 0.25
_C.MODEL.LOSSES.FL.GAMMA = 2
_C.MODEL.LOSSES.FL.SCALE = 1.0
# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
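As a sketch of how the regrouped options might be consumed, assuming fast-reid's detectron2-style `get_cfg` entry point (the exact values here are illustrative, not part of this commit):

from fastreid.config import get_cfg  # assumed import path

cfg = get_cfg()
# enable both classification losses; the names are matched when losses are built
cfg.MODEL.LOSSES.NAME = ("CrossEntropyLoss", "FocalLoss")
cfg.MODEL.LOSSES.CE.EPSILON = -1      # -1 selects adaptive label smoothing
cfg.MODEL.LOSSES.CE.ALPHA = 0.2
cfg.MODEL.LOSSES.FL.GAMMA = 2.0
cfg.MODEL.HEADS.ARCFACE.MARGIN = 0.5  # per-head hyper-parameters now live under HEADS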
@@ -136,7 +156,8 @@ _C.SOLVER.OPT = "Adam"
_C.SOLVER.MAX_ITER = 40000
_C.SOLVER.BASE_LR = 3e-4
_C.SOLVER.BIAS_LR_FACTOR = 1
_C.SOLVER.BIAS_LR_FACTOR = 1.
_C.SOLVER.HEADS_LR_FACTOR = 1.
_C.SOLVER.MOMENTUM = 0.9
@@ -144,11 +165,11 @@ _C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
_C.SOLVER.SCHED = "warmup"
# warmup config
# Warmup config
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (30, 55)
# cosine annealing
# Cosine annealing
_C.SOLVER.DELAY_ITERS = 100
_C.SOLVER.COS_ANNEAL_ITERS = 100

View File

@@ -23,6 +23,7 @@ def build_reid_train_loader(cfg):
for d in cfg.DATASETS.NAMES:
logger.info('prepare training set {}'.format(d))
dataset = DATASET_REGISTRY.get(d)()
dataset.show_summary()
train_items.extend(dataset.train)
train_set = CommDataset(train_items, train_transforms, relabel=True)
@@ -52,6 +53,7 @@ def build_reid_test_loader(cfg, dataset_name):
logger = logging.getLogger(__name__)
logger.info('prepare test set {}'.format(dataset_name))
dataset = DATASET_REGISTRY.get(dataset_name)()
dataset.show_summary()
test_items = dataset.query + dataset.gallery
test_set = CommDataset(test_items, test_transforms, relabel=False)
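With the added show_summary() calls, building a loader now logs one dataset table per dataset. A hypothetical minimal usage, assuming these are the public import paths:

from fastreid.config import get_cfg           # assumed import paths
from fastreid.data import build_reid_train_loader

cfg = get_cfg()
cfg.DATASETS.NAMES = ("Market1501",)          # any name registered in DATASET_REGISTRY
train_loader = build_reid_train_loader(cfg)   # summary table is logged during construction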

View File

@@ -9,6 +9,7 @@ import os
import numpy as np
import torch
import logging
class Dataset(object):
@@ -197,19 +198,19 @@ class ImageDataset(Dataset):
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
def show_summary(self):
logger = logging.getLogger(__name__)
num_train_pids, num_train_cams = self.parse_data(self.train)
num_query_pids, num_query_cams = self.parse_data(self.query)
num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
print('=> Loaded {}'.format(self.__class__.__name__))
print(' ----------------------------------------')
print(' subset | # ids | # images | # cameras')
print(' ----------------------------------------')
print(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
print(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
print(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
print(' ----------------------------------------')
logger.info('=> Loaded {}'.format(self.__class__.__name__))
logger.info(' ----------------------------------------')
logger.info(' subset | # ids | # images | # cameras')
logger.info(' ----------------------------------------')
logger.info(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
logger.info(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
logger.info(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
logger.info(' ----------------------------------------')
# class VideoDataset(Dataset):
# """A base class representing VideoDataset.
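Since the summary now goes through logging rather than print, it only appears once a handler is installed (fast-reid's own logger setup normally does this). A minimal standalone sketch, with Market1501 as a hypothetical stand-in for any registered ImageDataset subclass:

import logging
import sys

# route INFO-level records, including the show_summary table, to stdout
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format='%(name)s: %(message)s')

dataset = Market1501()  # hypothetical: any ImageDataset subclass
dataset.show_summary()  # emitted via logger.info, no longer printed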

View File

@@ -13,6 +13,7 @@ from torch.nn import Parameter
from .build import REID_HEADS_REGISTRY
from .linear_head import LinearHead
from ..losses.loss_utils import one_hot
from ..model_utils import weights_init_kaiming
from ...layers import NoBiasBatchNorm1d, Flatten
@@ -32,9 +33,8 @@ class ArcfaceHead(nn.Module):
self.bnneck.apply(weights_init_kaiming)
# classifier
# self.adaptive_s = False
self._s = 30.0
self._m = 0.50
self._s = cfg.MODEL.HEADS.ARCFACE.SCALE
self._m = cfg.MODEL.HEADS.ARCFACE.MARGIN
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.reset_parameters()
@@ -52,24 +52,13 @@
cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
# add margin
theta = torch.acos(torch.clamp(cosine, -1.0+1e-7, 1.0-1e-7))
theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
phi = torch.cos(theta + self._m)
# --------------------------- convert label to one-hot ---------------------------
one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
# if self.adaptive_s:
# with torch.no_grad():
# B_avg = torch.where(one_hot < 1, torch.exp(self._s * cosine), torch.zeros_like(cosine))
# B_avg = torch.sum(B_avg) / cosine.size(0)
# theta_med = torch.median(theta[one_hot == 1])
# s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med)))
# else:
# s = self._s
# you can use torch.where if your torch.__version__ is 0.4
pred_class_logits = one_hot * phi + (1.0 - one_hot) * cosine
targets = one_hot(targets, self._num_classes)
pred_class_logits = targets * phi + (1.0 - targets) * cosine
# logits re-scale
pred_class_logits *= self._s
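A toy numeric sketch of the margin step above, with made-up cosine values; F.one_hot stands in for the project's own one_hot helper:

import torch
import torch.nn.functional as F

cosine = torch.tensor([[0.90, 0.10, -0.20],
                       [0.30, 0.70,  0.00]])   # (batch=2, classes=3)
targets = F.one_hot(torch.tensor([0, 1]), num_classes=3).float()
m, s = 0.50, 30.0                              # cfg.MODEL.HEADS.ARCFACE defaults

theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
phi = torch.cos(theta + m)            # margin shrinks only the target logit
logits = (targets * phi + (1.0 - targets) * cosine) * s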

View File

@@ -33,8 +33,8 @@ class CircleHead(nn.Module):
self.bnneck.apply(weights_init_kaiming)
# classifier
self._s = 128.0
self._m = 0.15
self._s = cfg.MODEL.HEADS.CIRCLE.SCALE
self._m = cfg.MODEL.HEADS.CIRCLE.MARGIN
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.reset_parameters()

View File

@@ -7,6 +7,7 @@ import torch
import torch.nn.functional as F
from ...utils.events import get_event_storage
from .loss_utils import one_hot
class CrossEntropyLoss(object):
@@ -16,9 +17,9 @@ class CrossEntropyLoss(object):
def __init__(self, cfg):
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self._epsilon = cfg.MODEL.LOSSES.EPSILON
self._smooth_on = cfg.MODEL.LOSSES.SMOOTH_ON
self._scale = cfg.MODEL.LOSSES.SCALE_CE
self._eps = cfg.MODEL.LOSSES.CE.EPSILON
self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
self._scale = cfg.MODEL.LOSSES.CE.SCALE
self._topk = (1,)
@@ -47,14 +48,20 @@
scalar Tensor
"""
self._log_accuracy(pred_class_logits, gt_classes)
if self._smooth_on:
log_probs = F.log_softmax(pred_class_logits, dim=1)
targets = torch.zeros(log_probs.size()).scatter_(1, gt_classes.unsqueeze(1).data.cpu(), 1)
targets = targets.to(pred_class_logits.device)
targets = (1 - self._epsilon) * targets + self._epsilon / self._num_classes
loss = (-targets * log_probs).mean(0).sum()
if self._eps >= 0:
smooth_param = self._eps
else:
loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
# adaptive label smoothing regularization (LSR)
soft_label = F.softmax(pred_class_logits, dim=1)
smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
log_probs = F.log_softmax(pred_class_logits, dim=1)
with torch.no_grad():
targets = torch.ones_like(log_probs)
targets *= smooth_param / (self._num_classes - 1)
targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
loss = (-targets * log_probs).mean(0).sum()
return {
"loss_cls": loss * self._scale,
}
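A standalone sketch of the branch logic above with made-up logits: a fixed epsilon spreads constant probability mass over the non-target classes, while the adaptive path (epsilon == -1) scales the smoothing by the model's own confidence in the true class:

import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, 1.0, 0.5],
                       [0.2, 2.5, 0.3]])
gt = torch.tensor([0, 1])
alpha, num_classes = 0.2, 3

# adaptive branch: per-sample smoothing strength, shape (N, 1)
soft = F.softmax(logits, dim=1)
smooth = alpha * soft[torch.arange(soft.size(0)), gt].unsqueeze(1)

log_probs = F.log_softmax(logits, dim=1)
with torch.no_grad():
    targets = torch.ones_like(log_probs) * smooth / (num_classes - 1)
    targets.scatter_(1, gt.unsqueeze(1), 1 - smooth)
loss = (-targets * log_probs).mean(0).sum()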

View File

@@ -0,0 +1,114 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from .loss_utils import one_hot
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
def focal_loss(
input: torch.Tensor,
target: torch.Tensor,
alpha: float,
gamma: float = 2.0,
reduction: str = 'mean') -> torch.Tensor:
r"""Function that computes Focal loss.
See :class:`fastreid.modeling.losses.FocalLoss` for details.
"""
if not torch.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))
if not len(input.shape) >= 2:
raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
.format(input.shape))
if input.size(0) != target.size(0):
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
.format(input.size(0), target.size(0)))
n = input.size(0)
out_size = (n,) + input.size()[2:]
if target.size()[1:] != input.size()[2:]:
raise ValueError('Expected target size {}, got {}'.format(
out_size, target.size()))
if not input.device == target.device:
raise ValueError(
"input and target must be on the same device. Got: {} and {}".format(
input.device, target.device))
# compute softmax over the classes axis
input_soft = F.softmax(input, dim=1)
# create the labels one hot tensor
target_one_hot = one_hot(
target, num_classes=input.shape[1],
dtype=input.dtype)
# compute the actual focal loss
weight = torch.pow(-input_soft + 1., gamma)
focal = -alpha * weight * torch.log(input_soft)
loss_tmp = torch.sum(target_one_hot * focal, dim=1)
if reduction == 'none':
loss = loss_tmp
elif reduction == 'mean':
loss = torch.mean(loss_tmp)
elif reduction == 'sum':
loss = torch.sum(loss_tmp)
else:
raise NotImplementedError("Invalid reduction mode: {}"
.format(reduction))
return loss
class FocalLoss(object):
r"""Criterion that computes Focal loss.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments (read from cfg.MODEL.LOSSES.FL):
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float): Focusing parameter :math:`\gamma \geq 0`.
scale (float): Factor applied to the returned loss value.
The reduction over the batch is fixed to mean.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 \leq targets[i] \leq C-1`.
Examples:
>>> N = 5 # num_classes
>>> loss = FocalLoss(cfg) # cfg supplies MODEL.LOSSES.FL.{ALPHA, GAMMA, SCALE}
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, None, target)['loss_focal']
>>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""
# def __init__(self, alpha: float, gamma: float = 2.0,
# reduction: str = 'none') -> None:
def __init__(self, cfg):
self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA
self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA
self._scale: float = cfg.MODEL.LOSSES.FL.SCALE
def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict:
loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma)
return {
'loss_focal': loss * self._scale,
}
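A quick smoke test for the plain (N, C) logits case used by the re-id heads, assuming focal_loss is importable from this module; the unused second argument of __call__ is skipped by calling the function directly:

import torch

logits = torch.randn(8, 751, requires_grad=True)  # (batch, num_classes)
labels = torch.randint(0, 751, (8,))
loss = focal_loss(logits, labels, alpha=0.25, gamma=2.0, reduction='mean')
loss.backward()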

View File

@@ -0,0 +1,57 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
from typing import Optional
import torch
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
def one_hot(labels: torch.Tensor,
num_classes: int,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# eps: Optional[float] = 1e-6) -> torch.Tensor:
r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
Args:
labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
where N is batch size. Each value is an integer
representing correct classification.
num_classes (int): number of classes in labels.
dtype (Optional[torch.dtype]): the desired data type of returned
tensor. Default: if None, infers data type from values.
Returns:
torch.Tensor: the labels in one-hot tensor of shape :math:`(N, C, *)`.
Examples::
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> one_hot(labels, num_classes=3)
tensor([[[[1., 0.],
[0., 1.]],
[[0., 1.],
[0., 0.]],
[[0., 0.],
[1., 0.]]]])
"""
if not torch.is_tensor(labels):
raise TypeError("Input labels type is not a torch.Tensor. Got {}"
.format(type(labels)))
if not labels.dtype == torch.int64:
raise ValueError(
"labels must be of dtype torch.int64. Got: {}".format(
labels.dtype))
if num_classes < 1:
raise ValueError("The number of classes must be at least one."
" Got: {}".format(num_classes))
device = labels.device
shape = labels.shape
one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
device=device, dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
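A quick check of the helper on a flat label vector:

import torch

labels = torch.tensor([2, 0, 1])
print(one_hot(labels, num_classes=3, dtype=torch.float32))
# tensor([[0., 0., 1.],
#         [1., 0., 0.],
#         [0., 1., 0.]])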

View File

@@ -15,12 +15,13 @@ def build_optimizer(cfg, model):
continue
lr = cfg.SOLVER.BASE_LR
weight_decay = cfg.SOLVER.WEIGHT_DECAY
# if "heads" in key:
# lr = cfg.SOLVER.BASE_LR * 10
if "heads" in key:
lr *= cfg.SOLVER.HEADS_LR_FACTOR
if "bias" in key:
lr = lr * cfg.SOLVER.BIAS_LR_FACTOR
lr *= cfg.SOLVER.BIAS_LR_FACTOR
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
solver_opt = cfg.SOLVER.OPT
if hasattr(optim, solver_opt):
if solver_opt == "SGD":
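A self-contained sketch of the per-parameter learning rates the loop above produces, with a toy model and assumed factors (BASE_LR=3e-4, HEADS_LR_FACTOR=10, BIAS_LR_FACTOR=2):

import torch.nn as nn

model = nn.ModuleDict({'backbone': nn.Linear(4, 4),
                       'heads': nn.Linear(4, 2)})
base_lr, heads_factor, bias_factor = 3e-4, 10.0, 2.0

for name, value in model.named_parameters():
    lr = base_lr
    if 'heads' in name:
        lr *= heads_factor   # separate, typically higher, LR for the heads
    if 'bias' in name:
        lr *= bias_factor    # bias factor stacks on top of the heads factor
    print(name, lr)
# backbone.weight 0.0003, backbone.bias 0.0006,
# heads.weight 0.003, heads.bias 0.006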