From be9faa5605098d25a6d6cdb42e787860a2106b48 Mon Sep 17 00:00:00 2001
From: liaoxingyu
Date: Fri, 17 Apr 2020 13:46:10 +0800
Subject: [PATCH] update focal loss

update dataset info display
update separate lr
update adaptive label smooth regularization
---
 fastreid/config/defaults.py                   |  43 +++++--
 fastreid/data/build.py                        |   2 +
 fastreid/data/datasets/bases.py               |  19 +--
 fastreid/modeling/heads/arcface_head.py       |  23 +---
 fastreid/modeling/heads/circle_head.py        |   4 +-
 fastreid/modeling/losses/cross_entroy_loss.py |  27 +++--
 fastreid/modeling/losses/focal_loss.py        | 114 ++++++++++++++++++
 fastreid/modeling/losses/loss_utils.py        |  57 +++++++++
 fastreid/solver/build.py                      |   7 +-
 9 files changed, 244 insertions(+), 52 deletions(-)
 create mode 100644 fastreid/modeling/losses/focal_loss.py
 create mode 100644 fastreid/modeling/losses/loss_utils.py

diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index aef27d2..1f62614 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -48,6 +48,16 @@ _C.MODEL.HEADS.NAME = "BNneckHead"
 _C.MODEL.HEADS.POOL_LAYER = 'avgpool'
 _C.MODEL.HEADS.NUM_CLASSES = 751
 
+# Arcface head
+_C.MODEL.HEADS.ARCFACE = CN()
+_C.MODEL.HEADS.ARCFACE.MARGIN = 0.5
+_C.MODEL.HEADS.ARCFACE.SCALE = 30.0
+
+# Circle Loss
+_C.MODEL.HEADS.CIRCLE = CN()
+_C.MODEL.HEADS.CIRCLE.MARGIN = 0.15
+_C.MODEL.HEADS.CIRCLE.SCALE = 128.0
+
 # ---------------------------------------------------------------------------- #
 # REID LOSSES options
 # ---------------------------------------------------------------------------- #
@@ -55,16 +65,26 @@ _C.MODEL.LOSSES = CN()
 _C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
 
 # Cross Entropy Loss options
-_C.MODEL.LOSSES.SMOOTH_ON = False
-_C.MODEL.LOSSES.EPSILON = 0.1
-_C.MODEL.LOSSES.SCALE_CE = 1.0
+_C.MODEL.LOSSES.CE = CN()
+# If epsilon == 0, no label smooth regularization is applied;
+# if epsilon == -1, adaptive label smooth regularization is used.
+_C.MODEL.LOSSES.CE.EPSILON = 0.0
+_C.MODEL.LOSSES.CE.ALPHA = 0.2
+_C.MODEL.LOSSES.CE.SCALE = 1.0
 
 # Triplet Loss options
-_C.MODEL.LOSSES.MARGIN = 0.3
-_C.MODEL.LOSSES.NORM_FEAT = False
-_C.MODEL.LOSSES.HARD_MINING = True
-_C.MODEL.LOSSES.USE_COSINE_DIST = True
-_C.MODEL.LOSSES.SCALE_TRI = 1.0
+_C.MODEL.LOSSES.TRI = CN()
+_C.MODEL.LOSSES.TRI.MARGIN = 0.3
+_C.MODEL.LOSSES.TRI.NORM_FEAT = False
+_C.MODEL.LOSSES.TRI.HARD_MINING = True
+_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True
+_C.MODEL.LOSSES.TRI.SCALE = 1.0
+
+# Focal Loss options
+_C.MODEL.LOSSES.FL = CN()
+_C.MODEL.LOSSES.FL.ALPHA = 0.25
+_C.MODEL.LOSSES.FL.GAMMA = 2
+_C.MODEL.LOSSES.FL.SCALE = 1.0
 
 # Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
 # to be loaded to the model. You can find available models in the model zoo.
@@ -136,7 +156,8 @@ _C.SOLVER.OPT = "Adam"
 _C.SOLVER.MAX_ITER = 40000
 
 _C.SOLVER.BASE_LR = 3e-4
-_C.SOLVER.BIAS_LR_FACTOR = 1
+_C.SOLVER.BIAS_LR_FACTOR = 1.
+_C.SOLVER.HEADS_LR_FACTOR = 1.
 
 _C.SOLVER.MOMENTUM = 0.9
 
@@ -144,11 +165,11 @@ _C.SOLVER.WEIGHT_DECAY = 0.0005
 _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
 
 _C.SOLVER.SCHED = "warmup"
 
-# warmup config
+# Warmup config
 _C.SOLVER.GAMMA = 0.1
 _C.SOLVER.STEPS = (30, 55)
 
-# cosine annealing
+# Cosine annealing
 _C.SOLVER.DELAY_ITERS = 100
 _C.SOLVER.COS_ANNEAL_ITERS = 100
 
diff --git a/fastreid/data/build.py b/fastreid/data/build.py
index da5b4b0..1ef3407 100644
--- a/fastreid/data/build.py
+++ b/fastreid/data/build.py
@@ -23,6 +23,7 @@ def build_reid_train_loader(cfg):
     for d in cfg.DATASETS.NAMES:
         logger.info('prepare training set {}'.format(d))
         dataset = DATASET_REGISTRY.get(d)()
+        dataset.show_summary()
        train_items.extend(dataset.train)
 
     train_set = CommDataset(train_items, train_transforms, relabel=True)
@@ -52,6 +53,7 @@ def build_reid_test_loader(cfg, dataset_name):
     logger = logging.getLogger(__name__)
     logger.info('prepare test set {}'.format(dataset_name))
     dataset = DATASET_REGISTRY.get(dataset_name)()
+    dataset.show_summary()
     test_items = dataset.query + dataset.gallery
 
     test_set = CommDataset(test_items, test_transforms, relabel=False)
diff --git a/fastreid/data/datasets/bases.py b/fastreid/data/datasets/bases.py
index 21d3266..dc30663 100644
--- a/fastreid/data/datasets/bases.py
+++ b/fastreid/data/datasets/bases.py
@@ -9,6 +9,7 @@ import os
 
 import numpy as np
 import torch
+import logging
 
 
 class Dataset(object):
@@ -197,19 +198,19 @@ class ImageDataset(Dataset):
         super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
 
     def show_summary(self):
+        logger = logging.getLogger(__name__)
         num_train_pids, num_train_cams = self.parse_data(self.train)
         num_query_pids, num_query_cams = self.parse_data(self.query)
         num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
 
-        print('=> Loaded {}'.format(self.__class__.__name__))
-        print('  ----------------------------------------')
-        print('  subset   | # ids | # images | # cameras')
-        print('  ----------------------------------------')
-        print('  train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
-        print('  query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
-        print('  gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
-        print('  ----------------------------------------')
-
+        logger.info('=> Loaded {}'.format(self.__class__.__name__))
+        logger.info('  ----------------------------------------')
+        logger.info('  subset   | # ids | # images | # cameras')
+        logger.info('  ----------------------------------------')
+        logger.info('  train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
+        logger.info('  query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
+        logger.info('  gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
+        logger.info('  ----------------------------------------')
 
 # class VideoDataset(Dataset):
 #     """A base class representing VideoDataset.
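The ArcFace head and the new focal loss in the hunks below stop building one-hot masks inline and instead call the shared one_hot helper this patch adds in fastreid/modeling/losses/loss_utils.py. A minimal usage sketch of that conversion for plain 1-D labels, with made-up label values:

    import torch
    from fastreid.modeling.losses.loss_utils import one_hot

    labels = torch.tensor([2, 0, 1])          # (N,) class indices, must be int64
    onehot = one_hot(labels, num_classes=4)   # (N, 4) float tensor on labels.device
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.],
    #         [0., 1., 0., 0.]])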
diff --git a/fastreid/modeling/heads/arcface_head.py b/fastreid/modeling/heads/arcface_head.py
index 14e8882..b1d1056 100644
--- a/fastreid/modeling/heads/arcface_head.py
+++ b/fastreid/modeling/heads/arcface_head.py
@@ -13,6 +13,7 @@ from torch.nn import Parameter
 
 from .build import REID_HEADS_REGISTRY
 from .linear_head import LinearHead
+from ..losses.loss_utils import one_hot
 from ..model_utils import weights_init_kaiming
 from ...layers import NoBiasBatchNorm1d, Flatten
 
@@ -32,9 +33,8 @@ class ArcfaceHead(nn.Module):
         self.bnneck.apply(weights_init_kaiming)
 
         # classifier
-        # self.adaptive_s = False
-        self._s = 30.0
-        self._m = 0.50
+        self._s = cfg.MODEL.HEADS.ARCFACE.SCALE
+        self._m = cfg.MODEL.HEADS.ARCFACE.MARGIN
 
         self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
         self.reset_parameters()
@@ -52,24 +52,13 @@ class ArcfaceHead(nn.Module):
         cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
 
         # add margin
-        theta = torch.acos(torch.clamp(cosine, -1.0+1e-7, 1.0-1e-7))
+        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
         phi = torch.cos(theta + self._m)
 
         # --------------------------- convert label to one-hot ---------------------------
-        one_hot = torch.zeros_like(cosine)
-        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
-
-        # if self.adaptive_s:
-        #     with torch.no_grad():
-        #         B_avg = torch.where(one_hot < 1, torch.exp(self._s * cosine), torch.zeros_like(cosine))
-        #         B_avg = torch.sum(B_avg) / cosine.size(0)
-        #         theta_med = torch.median(theta[one_hot == 1])
-        #         s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med)))
-        # else:
-        #     s = self._s
-        # you can use torch.where if your torch.__version__ is 0.4
-        pred_class_logits = one_hot * phi + (1.0 - one_hot) * cosine
+        targets = one_hot(targets, self._num_classes)
+        pred_class_logits = targets * phi + (1.0 - targets) * cosine
 
         # logits re-scale
         pred_class_logits *= self._s
diff --git a/fastreid/modeling/heads/circle_head.py b/fastreid/modeling/heads/circle_head.py
index 0e0289d..96c5bb7 100644
--- a/fastreid/modeling/heads/circle_head.py
+++ b/fastreid/modeling/heads/circle_head.py
@@ -33,8 +33,8 @@ class CircleHead(nn.Module):
         self.bnneck.apply(weights_init_kaiming)
 
         # classifier
-        self._s = 128.0
-        self._m = 0.15
+        self._s = cfg.MODEL.HEADS.CIRCLE.SCALE
+        self._m = cfg.MODEL.HEADS.CIRCLE.MARGIN
 
         self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
         self.reset_parameters()
diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py
index afc6f90..4945aa4 100644
--- a/fastreid/modeling/losses/cross_entroy_loss.py
+++ b/fastreid/modeling/losses/cross_entroy_loss.py
@@ -7,6 +7,7 @@ import torch
 import torch.nn.functional as F
 
 from ...utils.events import get_event_storage
+from .loss_utils import one_hot
 
 
 class CrossEntropyLoss(object):
@@ -16,9 +17,9 @@ class CrossEntropyLoss(object):
 
     def __init__(self, cfg):
         self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
-        self._epsilon = cfg.MODEL.LOSSES.EPSILON
-        self._smooth_on = cfg.MODEL.LOSSES.SMOOTH_ON
-        self._scale = cfg.MODEL.LOSSES.SCALE_CE
+        self._eps = cfg.MODEL.LOSSES.CE.EPSILON
+        self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
+        self._scale = cfg.MODEL.LOSSES.CE.SCALE
 
         self._topk = (1,)
 
@@ -47,14 +48,20 @@ class CrossEntropyLoss(object):
             scalar Tensor
         """
         self._log_accuracy(pred_class_logits, gt_classes)
-        if self._smooth_on:
-            log_probs = F.log_softmax(pred_class_logits, dim=1)
-            targets = torch.zeros(log_probs.size()).scatter_(1, gt_classes.unsqueeze(1).data.cpu(), 1)
-            targets = targets.to(pred_class_logits.device)
-            targets = (1 - self._epsilon) * targets + self._epsilon / self._num_classes
-            loss = (-targets * log_probs).mean(0).sum()
+        if self._eps >= 0:
+            smooth_param = self._eps
         else:
-            loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
+            # adaptive label smooth regularization
+            soft_label = F.softmax(pred_class_logits, dim=1)
+            smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
+
+        log_probs = F.log_softmax(pred_class_logits, dim=1)
+        with torch.no_grad():
+            targets = torch.ones_like(log_probs)
+            targets *= smooth_param / (self._num_classes - 1)
+            targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
+
+        loss = (-targets * log_probs).mean(0).sum()
         return {
             "loss_cls": loss * self._scale,
         }
diff --git a/fastreid/modeling/losses/focal_loss.py b/fastreid/modeling/losses/focal_loss.py
new file mode 100644
index 0000000..aec3118
--- /dev/null
+++ b/fastreid/modeling/losses/focal_loss.py
@@ -0,0 +1,114 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+from .loss_utils import one_hot
+
+
+# based on:
+# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
+
+def focal_loss(
+        input: torch.Tensor,
+        target: torch.Tensor,
+        alpha: float,
+        gamma: float = 2.0,
+        reduction: str = 'mean', ) -> torch.Tensor:
+    r"""Function that computes Focal loss.
+    See :class:`fastreid.modeling.losses.FocalLoss` for details.
+    """
+    if not torch.is_tensor(input):
+        raise TypeError("Input type is not a torch.Tensor. Got {}"
+                        .format(type(input)))
+
+    if not len(input.shape) >= 2:
+        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
+                         .format(input.shape))
+
+    if input.size(0) != target.size(0):
+        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
+                         .format(input.size(0), target.size(0)))
+
+    n = input.size(0)
+    out_size = (n,) + input.size()[2:]
+    if target.size()[1:] != input.size()[2:]:
+        raise ValueError('Expected target size {}, got {}'.format(
+            out_size, target.size()))
+
+    if not input.device == target.device:
+        raise ValueError(
+            "input and target must be on the same device. Got: {} and {}".format(
+                input.device, target.device))
+
+    # compute softmax over the classes axis
+    input_soft = F.softmax(input, dim=1)
+
+    # create the labels one hot tensor
+    target_one_hot = one_hot(
+        target, num_classes=input.shape[1],
+        dtype=input.dtype)
+
+    # compute the actual focal loss
+    weight = torch.pow(-input_soft + 1., gamma)
+
+    focal = -alpha * weight * torch.log(input_soft)
+    loss_tmp = torch.sum(target_one_hot * focal, dim=1)
+
+    if reduction == 'none':
+        loss = loss_tmp
+    elif reduction == 'mean':
+        loss = torch.mean(loss_tmp)
+    elif reduction == 'sum':
+        loss = torch.sum(loss_tmp)
+    else:
+        raise NotImplementedError("Invalid reduction mode: {}"
+                                  .format(reduction))
+    return loss
+
+
+class FocalLoss(object):
+    r"""Criterion that computes Focal loss.
+    According to [1], the Focal loss is computed as follows:
+    .. math::
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+    where:
+       - :math:`p_t` is the model's estimated probability for each class.
+    Arguments:
+        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
+        gamma (float): Focusing parameter :math:`\gamma >= 0`.
+        reduction (str, optional): Specifies the reduction to apply to the
+         output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied,
+         ‘mean’: the sum of the output will be divided by the number of elements
+         in the output, ‘sum’: the output will be summed. Default: ‘mean’.
+    Shape:
+        - Input: :math:`(N, C, *)` where C = number of classes.
+        - Target: :math:`(N, *)` where each value is
+          :math:`0 ≤ targets[i] ≤ C−1`.
+    Examples:
+        >>> N = 5  # num_classes
+        >>> loss = FocalLoss(cfg)
+        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+        >>> output = loss(input, None, target)
+        >>> output.backward()
+    References:
+        [1] https://arxiv.org/abs/1708.02002
+    """
+
+    # def __init__(self, alpha: float, gamma: float = 2.0,
+    #              reduction: str = 'none') -> None:
+    def __init__(self, cfg):
+        self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA
+        self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA
+        self._scale: float = cfg.MODEL.LOSSES.FL.SCALE
+
+    def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict:
+        loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma)
+        return {
+            'loss_focal': loss * self._scale,
+        }
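Before the companion utility file, a quick numeric check of the functional form above, using made-up logits (values chosen only for illustration):

    import torch
    from fastreid.modeling.losses.focal_loss import focal_loss

    logits = torch.tensor([[2.0, 0.5, -1.0]])  # (N=1, C=3) raw class scores
    target = torch.tensor([0])                 # ground-truth class index, int64
    loss = focal_loss(logits, target, alpha=0.25, gamma=2.0)
    # p_t = softmax(logits)[0, 0] ≈ 0.786; the one-hot mask keeps only that term:
    # loss ≈ -0.25 * (1 - 0.786)**2 * log(0.786) ≈ 0.0028

The FocalLoss wrapper drives this same function with the cfg-supplied alpha and gamma; its ignored second argument appears to exist only so the call matches the three-argument (logits, features, targets) convention used by the other losses in this codebase.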
diff --git a/fastreid/modeling/losses/loss_utils.py b/fastreid/modeling/losses/loss_utils.py
new file mode 100644
index 0000000..c8d5548
--- /dev/null
+++ b/fastreid/modeling/losses/loss_utils.py
@@ -0,0 +1,57 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: liaoxingyu5@jd.com
+"""
+
+from typing import Optional
+
+import torch
+
+# based on:
+# https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
+
+
+def one_hot(labels: torch.Tensor,
+            num_classes: int,
+            dtype: Optional[torch.dtype] = None,) -> torch.Tensor:
+    # eps: Optional[float] = 1e-6) -> torch.Tensor:
+    r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
+    Args:
+        labels (torch.Tensor): tensor with labels of shape :math:`(N, *)`,
+         where N is batch size. Each value is an integer
+         representing correct classification.
+        num_classes (int): number of classes in labels.
+        dtype (Optional[torch.dtype]): the desired data type of returned
+         tensor. Default: if None, infers data type from values.
+    Returns:
+        torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
+         created on the same device as labels.
+    Examples::
+        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
+        >>> one_hot(labels, num_classes=3)
+        tensor([[[[1., 0.],
+                  [0., 1.]],
+                 [[0., 1.],
+                  [0., 0.]],
+                 [[0., 0.],
+                  [1., 0.]]]])
+    """
+    if not torch.is_tensor(labels):
+        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
+                        .format(type(labels)))
+    if not labels.dtype == torch.int64:
+        raise ValueError(
+            "labels must be of dtype torch.int64. Got: {}".format(
+                labels.dtype))
+    if num_classes < 1:
+        raise ValueError("The number of classes must be at least one."
+                         " Got: {}".format(num_classes))
+    device = labels.device
+    shape = labels.shape
+    one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
+                          device=device, dtype=dtype)
+    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
\ No newline at end of file
diff --git a/fastreid/solver/build.py b/fastreid/solver/build.py
index a696f53..b58beaa 100644
--- a/fastreid/solver/build.py
+++ b/fastreid/solver/build.py
@@ -15,12 +15,13 @@ def build_optimizer(cfg, model):
             continue
         lr = cfg.SOLVER.BASE_LR
         weight_decay = cfg.SOLVER.WEIGHT_DECAY
-        # if "heads" in key:
-        #     lr = cfg.SOLVER.BASE_LR * 10
+        if "heads" in key:
+            lr *= cfg.SOLVER.HEADS_LR_FACTOR
         if "bias" in key:
-            lr = lr * cfg.SOLVER.BIAS_LR_FACTOR
+            lr *= cfg.SOLVER.BIAS_LR_FACTOR
             weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
+
     solver_opt = cfg.SOLVER.OPT
     if hasattr(optim, solver_opt):
         if solver_opt == "SGD":