fast-reid/fastreid/modeling/meta_arch/baseline.py

121 lines
4.1 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
2020-02-10 07:38:56 +08:00
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_heads
2020-07-06 16:57:43 +08:00
from fastreid.modeling.losses import *
2020-02-10 07:38:56 +08:00
from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    """Baseline meta-architecture: a backbone followed by heads.

    In training mode ``forward`` returns a dict of weighted losses; in eval
    mode it returns the heads' output for the batch.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config node. ``MODEL.PIXEL_MEAN`` / ``MODEL.PIXEL_STD`` are read
                here, ``MODEL.LOSSES.*`` is read in ``losses``, and the whole
                node is passed to ``build_backbone`` / ``build_heads``.
        """
        super().__init__()
        self._cfg = cfg

        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        # Buffers (not plain attributes) so they move with the module on
        # .to()/.cuda() and are saved in the state dict. Shape (1, C, 1, 1) so
        # they broadcast per channel -- presumably over (N, C, H, W) images.
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

        # head
        self.heads = build_heads(cfg)

    @property
    def device(self):
        """Device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a dict with key ``"images"`` (and ``"targets"`` in
                training), or a bare image tensor (inference only).

        Returns:
            Training: the dict produced by ``losses``.
            Eval: the heads' output for the batch.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)

        if not self.training:
            return self.heads(features)

        # Check for a dict first: `"targets" in <Tensor>` raises an unrelated
        # RuntimeError from Tensor.__contains__ instead of this message.
        assert isinstance(batched_inputs, dict) and "targets" in batched_inputs, \
            "Person ID annotation are missing in training!"
        targets = batched_inputs["targets"].to(self.device)

        # PreciseBN flag: when PreciseBN runs on a different dataset, the number
        # of classes there may exceed the original one, so circle/arcface heads
        # would throw an error. Such batches carry negative targets (sum < 0);
        # replace them all with 0. Out-of-place so the caller's label tensor
        # (which `.to()` may return unchanged) is not overwritten.
        if targets.sum() < 0:
            targets = torch.zeros_like(targets)

        outputs = self.heads(features, targets)
        return self.losses(outputs, targets)

    def preprocess_image(self, batched_inputs):
        """Move the input images to the model device and normalise them.

        Args:
            batched_inputs: a dict holding the images under ``"images"``, or
                the image tensor itself.

        Returns:
            ``(images - pixel_mean) / pixel_std`` as a new tensor; the input is
            left untouched.

        Raises:
            TypeError: if ``batched_inputs`` is neither a dict nor a Tensor.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs["images"].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))
        # Out-of-place on purpose. `.to()` returns the same tensor when it is
        # already on the target device, so in-place sub_/div_ would overwrite
        # the caller's data. In-place ops also fail on integer dtypes and on
        # leaf tensors that require grad.
        return (images - self.pixel_mean) / self.pixel_std

    def losses(self, outputs, gt_labels):
        """Compute the configured losses from the heads' training outputs.

        Args:
            outputs: dict from the heads; must contain ``"pred_class_logits"``,
                ``"cls_outputs"`` and ``"features"``.
            gt_labels: ground-truth identity labels for the batch.

        Returns:
            A dict containing a subset of ``"loss_cls"``, ``"loss_triplet"``,
            ``"loss_circle"`` and ``"loss_cosface"``. A key is present when the
            matching name ("CrossEntropyLoss", "TripletLoss", "CircleLoss",
            "Cosface") is in ``MODEL.LOSSES.NAME``. Each value is the loss
            multiplied by its configured ``SCALE``.
        """
        # model predictions
        # fmt: off
        pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs       = outputs['cls_outputs']
        pred_features     = outputs['features']
        # fmt: on

        # Log prediction accuracy (logits are detached above, so this does not
        # feed back into the graph).
        log_accuracy(pred_class_logits, gt_labels)

        loss_cfg = self._cfg.MODEL.LOSSES
        loss_names = loss_cfg.NAME
        loss_dict = {}

        if "CrossEntropyLoss" in loss_names:
            loss_dict["loss_cls"] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                loss_cfg.CE.EPSILON,
                loss_cfg.CE.ALPHA,
            ) * loss_cfg.CE.SCALE

        if "TripletLoss" in loss_names:
            loss_dict["loss_triplet"] = triplet_loss(
                pred_features,
                gt_labels,
                loss_cfg.TRI.MARGIN,
                loss_cfg.TRI.NORM_FEAT,
                loss_cfg.TRI.HARD_MINING,
            ) * loss_cfg.TRI.SCALE

        if "CircleLoss" in loss_names:
            loss_dict["loss_circle"] = pairwise_circleloss(
                pred_features,
                gt_labels,
                loss_cfg.CIRCLE.MARGIN,
                loss_cfg.CIRCLE.GAMMA,
            ) * loss_cfg.CIRCLE.SCALE

        if "Cosface" in loss_names:
            loss_dict["loss_cosface"] = pairwise_cosface(
                pred_features,
                gt_labels,
                loss_cfg.COSFACE.MARGIN,
                loss_cfg.COSFACE.GAMMA,
            ) * loss_cfg.COSFACE.SCALE

        return loss_dict