# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

from fastreid.config import configurable
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    """
    Baseline architecture. Any model that contains the following two components:
    1. Per-image feature extraction (aka backbone)
    2. Per-image feature aggregation and loss computation
    """

    @configurable
    def __init__(
            self,
            *,
            backbone,
            heads,
            pixel_mean,
            pixel_std,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone: module called on the normalized image batch to produce features.
            heads: module called as ``heads(features, targets)`` in training mode
                and as ``heads(features)`` otherwise.
            pixel_mean: per-channel values subtracted from the input images;
                stored reshaped to ``(1, C, 1, 1)``.
            pixel_std: per-channel values the input images are divided by;
                stored reshaped to ``(1, C, 1, 1)``.
            loss_kwargs: dict holding ``'loss_names'`` plus one sub-dict of
                hyperparameters per loss (see :meth:`from_config` for the layout).
                It is only read by :meth:`losses`, i.e. in training mode.
        """
        super().__init__()
        # backbone
        self.backbone = backbone

        # head
        self.heads = heads

        self.loss_kwargs = loss_kwargs

        # Non-persistent buffers: they follow the module across devices but are
        # left out of the state_dict.
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), persistent=False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), persistent=False)

    @classmethod
    def from_config(cls, cfg):
        """
        Translate a config node into the keyword arguments of ``__init__``.

        Builds the backbone and heads from ``cfg`` and copies the pixel
        statistics and the loss settings found under ``cfg.MODEL``.
        """
        backbone = build_backbone(cfg)
        heads = build_heads(cfg)
        return {
            'backbone': backbone,
            'heads': heads,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs':
                {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    },
                    'contrastive': {
                        'margin': cfg.MODEL.LOSSES.CONTRASTIVE.MARGIN,
                        'scale': cfg.MODEL.LOSSES.CONTRASTIVE.SCALE
                    }
                }
        }

    @property
    def device(self):
        """Device on which the ``pixel_mean`` buffer currently lives."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Run the backbone and heads on a batch.

        Args:
            batched_inputs: a dict with an ``'images'`` entry (and a
                ``'targets'`` entry in training mode), or a plain image tensor.

        Returns:
            In training mode, the dict produced by :meth:`losses`;
            otherwise whatever ``self.heads(features)`` returns.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)

        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"]

            # PreciseBN flag. When running PreciseBN on a different dataset, the
            # number of classes in the new dataset may be larger than in the
            # original one, so circle/arcface would throw an error. All targets
            # are set to 0 (in place) to avoid this problem.
            if targets.sum() < 0:
                targets.zero_()

            outputs = self.heads(features, targets)
            losses = self.losses(outputs, targets)
            return losses

        outputs = self.heads(features)
        return outputs

    def preprocess_image(self, batched_inputs):
        """
        Normalize the input images with ``pixel_mean`` and ``pixel_std``.

        The caller's tensor is left untouched: a new tensor is returned, so
        feeding the same batch (or a view/flip of it) through the model again
        does not normalize it a second time.

        Args:
            batched_inputs: a dict with an ``'images'`` entry, or an image tensor.

        Returns:
            A new tensor holding ``(images - pixel_mean) / pixel_std``.

        Raises:
            TypeError: if ``batched_inputs`` is neither a dict nor a ``torch.Tensor``.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs['images']
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))

        # Out-of-place subtraction allocates the result; the division can then
        # safely reuse that fresh tensor in place.
        return images.sub(self.pixel_mean).div_(self.pixel_std)

    def losses(self, outputs, gt_labels):
        """
        Compute the configured losses from the heads' training outputs.

        Args:
            outputs: mapping with at least ``'cls_outputs'`` (consumed by the
                cross-entropy loss) and ``'features'`` (consumed by the
                pairwise/metric losses).
            gt_labels: ground-truth labels passed to every loss function.

        Returns:
            dict mapping a loss key (``'loss_cls'``, ``'loss_triplet'``,
            ``'loss_circle'``, ``'loss_cosface'``, ``'loss_contrastive'``) to its
            value already multiplied by the configured ``scale``. Only losses
            listed in ``loss_kwargs['loss_names']`` are present.
        """
        # model predictions
        cls_outputs = outputs['cls_outputs']
        pred_features = outputs['features']

        # NOTE: prediction-accuracy logging is currently disabled, so the
        # detached 'pred_class_logits' entry is no longer fetched here.

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale')

        if 'TripletLoss' in loss_names:
            tri_kwargs = self.loss_kwargs.get('tri')
            loss_dict['loss_triplet'] = triplet_loss(
                pred_features,
                gt_labels,
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale')

        if 'CircleLoss' in loss_names:
            circle_kwargs = self.loss_kwargs.get('circle')
            loss_dict['loss_circle'] = pairwise_circleloss(
                pred_features,
                gt_labels,
                circle_kwargs.get('margin'),
                circle_kwargs.get('gamma')
            ) * circle_kwargs.get('scale')

        if 'Cosface' in loss_names:
            cosface_kwargs = self.loss_kwargs.get('cosface')
            loss_dict['loss_cosface'] = pairwise_cosface(
                pred_features,
                gt_labels,
                cosface_kwargs.get('margin'),
                cosface_kwargs.get('gamma'),
            ) * cosface_kwargs.get('scale')

        if 'ContrastiveLoss' in loss_names:
            contrastive_kwargs = self.loss_kwargs.get('contrastive')
            loss_dict['loss_contrastive'] = contrastive_loss(
                pred_features,
                gt_labels,
                contrastive_kwargs.get('margin')
            ) * contrastive_kwargs.get('scale')

        return loss_dict