mirror of https://github.com/JDAI-CV/fast-reid.git
201 lines
6.7 KiB
Python
201 lines
6.7 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from fastreid.config import configurable
|
|
from fastreid.modeling.backbones import build_backbone
|
|
from fastreid.modeling.heads import build_heads
|
|
from fastreid.modeling.losses import *
|
|
from .build import META_ARCH_REGISTRY
|
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    """
    Baseline architecture. Any model that contains the following two components:
    1. Per-image feature extraction (aka backbone)
    2. Per-image feature aggregation and loss computation
    """

    @configurable
    def __init__(
            self,
            *,
            backbone,
            heads,
            pixel_mean,
            pixel_std,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone: module called on the normalized image batch to produce features.
            heads: module called as ``heads(features, targets)`` while training and
                as ``heads(features)`` otherwise.
            pixel_mean: sequence of values subtracted from the images; reshaped to
                ``(1, -1, 1, 1)`` so it broadcasts along dim 1 of the batch.
            pixel_std: sequence of values the images are divided by; reshaped the
                same way as ``pixel_mean``.
            loss_kwargs: dict holding ``'loss_names'`` plus one hyperparameter
                sub-dict per loss, in the layout produced by ``from_config``.
                ``losses`` indexes into it, so it must be provided for training.
        """
        super().__init__()
        # backbone
        self.backbone = backbone

        # head
        self.heads = heads

        self.loss_kwargs = loss_kwargs

        # Non-persistent buffers: they move with the module (``.to()``/``.cuda()``)
        # but are left out of ``state_dict``.
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), persistent=False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), persistent=False)

    @classmethod
    def from_config(cls, cfg):
        """
        Translate ``cfg`` into the keyword arguments expected by ``__init__``.

        Builds the backbone and heads from ``cfg`` and collects the loss names and
        per-loss hyperparameters found under ``cfg.MODEL.LOSSES``.

        Returns:
            dict: keyword arguments for ``__init__``.
        """
        backbone = build_backbone(cfg)
        heads = build_heads(cfg)
        return {
            'backbone': backbone,
            'heads': heads,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs':
                {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    },
                    'contrastive': {
                        'margin': cfg.MODEL.LOSSES.CONTRASTIVE.MARGIN,
                        'scale': cfg.MODEL.LOSSES.CONTRASTIVE.SCALE
                    }
                }
        }

    @property
    def device(self):
        """Device on which the ``pixel_mean`` buffer currently lives."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Run the backbone and heads on a batch.

        Args:
            batched_inputs: a dict with an ``'images'`` entry (and a ``'targets'``
                entry while training), or a bare image tensor.

        Returns:
            While training, the dict of losses returned by ``losses``; otherwise
            whatever ``heads(features)`` returns.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)

        if not self.training:
            return self.heads(features)

        assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
        targets = batched_inputs["targets"]

        # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
        # may be larger than that in the original dataset, so the circle/arcface will
        # throw an error. We just set all the targets to 0 to avoid this problem.
        # NOTE: this zeroes the caller's ``targets`` tensor in place.
        if targets.sum() < 0:
            targets.zero_()

        outputs = self.heads(features, targets)
        return self.losses(outputs, targets)

    def preprocess_image(self, batched_inputs):
        """
        Normalize and batch the input images.

        Args:
            batched_inputs: a dict with an ``'images'`` tensor, or the image
                tensor itself.

        Returns:
            A new tensor equal to ``(images - pixel_mean) / pixel_std``. The input
            tensor is not modified.

        Raises:
            TypeError: if ``batched_inputs`` is neither a dict nor a ``torch.Tensor``.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs['images']
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))

        # Normalize out of place: an in-place ``sub_``/``div_`` would overwrite the
        # caller's tensor (so feeding the same batch twice normalizes it twice) and
        # cannot write float results into an integer image tensor.
        return (images - self.pixel_mean) / self.pixel_std

    def losses(self, outputs, gt_labels):
        """
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.

        Args:
            outputs: dict returned by the heads in training mode; the entries
                ``'cls_outputs'`` and ``'features'`` are read.
            gt_labels: ground-truth labels passed on to every loss function.

        Returns:
            dict: one scaled loss per name listed in
            ``self.loss_kwargs['loss_names']``, keyed ``'loss_cls'``,
            ``'loss_triplet'``, ``'loss_circle'``, ``'loss_cosface'`` and
            ``'loss_contrastive'``.
        """
        # model predictions
        cls_outputs = outputs['cls_outputs']
        pred_features = outputs['features']

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale')

        if 'TripletLoss' in loss_names:
            tri_kwargs = self.loss_kwargs.get('tri')
            loss_dict['loss_triplet'] = triplet_loss(
                pred_features,
                gt_labels,
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale')

        if 'CircleLoss' in loss_names:
            circle_kwargs = self.loss_kwargs.get('circle')
            loss_dict['loss_circle'] = pairwise_circleloss(
                pred_features,
                gt_labels,
                circle_kwargs.get('margin'),
                circle_kwargs.get('gamma')
            ) * circle_kwargs.get('scale')

        if 'Cosface' in loss_names:
            cosface_kwargs = self.loss_kwargs.get('cosface')
            loss_dict['loss_cosface'] = pairwise_cosface(
                pred_features,
                gt_labels,
                cosface_kwargs.get('margin'),
                cosface_kwargs.get('gamma'),
            ) * cosface_kwargs.get('scale')

        if 'ContrastiveLoss' in loss_names:
            contrastive_kwargs = self.loss_kwargs.get('contrastive')
            loss_dict['loss_contrastive'] = contrastive_loss(
                pred_features,
                gt_labels,
                contrastive_kwargs.get('margin')
            ) * contrastive_kwargs.get('scale')

        return loss_dict
|