mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
359 lines
15 KiB
Python
359 lines
15 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
from easycv.models.builder import HEADS, build_neck
|
||
|
from easycv.models.detection.utils import (HungarianMatcher, accuracy,
|
||
|
box_cxcywh_to_xyxy,
|
||
|
box_xyxy_to_cxcywh,
|
||
|
generalized_box_iou)
|
||
|
from easycv.models.utils import (MLP, get_world_size,
|
||
|
is_dist_avail_and_initialized)
|
||
|
|
||
|
|
||
|
@HEADS.register_module()
class DETRHead(nn.Module):
    """Implements the DETR transformer head.

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_classes (int): Number of categories excluding the background.
        embed_dims (int): Input width of the classification layer and of the
            box regression MLP.
        eos_coef (float): Forwarded to ``SetCriterion`` as the classification
            weight of the no-object class.
        transformer: Config handed to ``build_neck`` to create the
            transformer.
        cost_dict (dict, optional): Matching costs forwarded to
            ``HungarianMatcher``. When ``None``,
            ``{'cost_class': 1, 'cost_bbox': 5, 'cost_giou': 2}`` is used.
        weight_dict (dict, optional): Loss weights forwarded to
            ``SetCriterion``. When ``None``,
            ``{'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}`` is used.
        **kwargs: Accepted for config compatibility and ignored.
    """

    _version = 2

    def __init__(self,
                 num_classes,
                 embed_dims,
                 eos_coef=0.1,
                 transformer=None,
                 cost_dict=None,
                 weight_dict=None,
                 **kwargs):

        super(DETRHead, self).__init__()

        # The defaults are created per instance rather than in the signature,
        # so that no dict object is shared between different heads.
        if cost_dict is None:
            cost_dict = {
                'cost_class': 1,
                'cost_bbox': 5,
                'cost_giou': 2,
            }
        if weight_dict is None:
            weight_dict = {
                'loss_ce': 1,
                'loss_bbox': 5,
                'loss_giou': 2,
            }

        self.matcher = HungarianMatcher(cost_dict=cost_dict)
        self.criterion = SetCriterion(
            num_classes,
            matcher=self.matcher,
            weight_dict=weight_dict,
            eos_coef=eos_coef,
            losses=['labels', 'boxes', 'cardinality'])
        self.postprocess = PostProcess()
        self.transformer = build_neck(transformer)

        # One extra logit for the no-object class.
        self.class_embed = nn.Linear(embed_dims, num_classes + 1)
        self.bbox_embed = MLP(embed_dims, embed_dims, 4, 3)
        self.num_classes = num_classes

    def init_weights(self):
        """Initialize weights of the detr head."""
        self.transformer.init_weights()

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            dict: Predictions with the following keys.

                - ``pred_logits`` (Tensor): Classification logits taken from
                  the last entry along dim 0 of the transformer output. The
                  last dimension has ``num_classes + 1`` entries, i.e. it
                  includes the no-object class. Presumably of shape
                  [bs, num_query, num_classes + 1] -- TODO confirm against
                  the transformer neck.
                - ``pred_boxes`` (Tensor): Sigmoid-activated box regression
                  outputs for the same entry; the last dimension is 4.
                  Consumed downstream as normalized (cx, cy, w, h).
                - ``aux_outputs`` (list[dict]): One dict with ``pred_logits``
                  and ``pred_boxes`` for every remaining entry along dim 0
                  (presumably the intermediate decoder layers).
        """
        feats = self.transformer(feats, img_metas)

        outputs_class = self.class_embed(feats)
        outputs_coord = self.bbox_embed(feats).sigmoid()

        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1]
        }

        out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair up the logits and boxes of every entry except the last one."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{
            'pred_logits': a,
            'pred_boxes': b
        } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    # over-write because img_metas are needed as inputs for bbox_head.
    def forward_train(self, x, img_metas, gt_bboxes, gt_labels):
        """Forward function for training mode.

        Args:
            x (list[Tensor]): Features from backbone.
            img_metas (list[dict]): Meta information of each image. The key
                ``img_shape`` is read and unpacked as (height, width, _).
            gt_bboxes (list[Tensor]): Ground truth boxes of each image in
                absolute (x1, y1, x2, y2) format, each of shape (num_gts, 4).
                The list is not modified.
            gt_labels (list[Tensor]): Ground truth labels of each image,
                each of shape (num_gts,).

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        outputs = self.forward(x, img_metas)
        pred_boxes = outputs['pred_boxes']

        targets = []
        for img_meta, gt_bbox, gt_label in zip(img_metas, gt_bboxes,
                                               gt_labels):
            img_h, img_w, _ = img_meta['img_shape']
            # DETR regresses the relative position of boxes (cxcywh) in the
            # image. Thus the learning target is normalized by the image size
            # and converted from the default x1y1x2y2 format to cxcywh.
            factor = pred_boxes.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0)
            # A new tensor is created here on purpose: writing the result
            # back into ``gt_bboxes`` would corrupt the caller's data and
            # normalize it twice if the same inputs were used again.
            targets.append({
                'labels': gt_label,
                'boxes': box_xyxy_to_cxcywh(gt_bbox) / factor
            })

        losses = self.criterion(outputs, targets)

        return losses

    def forward_test(self, x, img_metas):
        """Forward function for test mode.

        Args:
            x (list[Tensor]): Features from backbone.
            img_metas (list[dict]): Meta information of each image. The key
                ``ori_shape`` is read and unpacked as (height, width, _).

        Returns:
            dict: The output of ``PostProcess``, with boxes rescaled to the
            original image sizes.
        """
        outputs = self.forward(x, img_metas)

        ori_shape_list = []
        for img_meta in img_metas:
            ori_h, ori_w, _ = img_meta['ori_shape']
            ori_shape_list.append(torch.as_tensor([ori_h, ori_w]))
        orig_target_sizes = torch.stack(ori_shape_list, dim=0)

        results = self.postprocess(outputs, orig_target_sizes, img_metas)
        return results
|
||
|
|
||
|
|
||
|
class PostProcess(nn.Module):
    """Converts the raw head outputs into per-image numpy detections."""

    @torch.no_grad()
    def forward(self, outputs, target_sizes, img_metas):
        """Turn logits and normalized boxes into scored absolute boxes.

        Parameters:
            outputs: raw outputs of the model; the keys ``pred_logits`` and
                     ``pred_boxes`` are read
            target_sizes: tensor of dimension [batch_size x 2] containing the
                          (height, width) of each image of the batch
                          For evaluation, this must be the original image size
                          (before any data augmentation)
                          For visualization, this should be the image size
                          after data augment, but before padding
            img_metas: image meta information, passed through unchanged

        Returns:
            dict with the keys ``detection_boxes``, ``detection_scores`` and
            ``detection_classes``, each a list holding one numpy array per
            image of the batch, plus ``img_metas``.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = F.softmax(out_logits, -1)
        # The last class is dropped before taking the maximum, so a query is
        # always assigned its best scoring real category.
        scores, labels = prob[..., :-1].max(-1)

        # convert to [x0, y0, x1, y1] format
        boxes = box_cxcywh_to_xyxy(out_bbox)
        # and from relative [0, 1] to absolute [0, width] / [0, height]
        # coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

        # Emit one entry per image. Returning only index 0 would silently
        # discard the detections of every other image whenever the batch
        # holds more than one image, while ``img_metas`` still covers all.
        results = {
            'detection_boxes': [box.cpu().numpy() for box in boxes],
            'detection_scores': [score.cpu().numpy() for score in scores],
            'detection_classes':
            [label.cpu().numpy().astype(np.int32) for label in labels],
            'img_metas': img_metas
        }

        return results
|
||
|
|
||
|
|
||
|
class SetCriterion(nn.Module):
    """Loss module for DETR.

    The loss is obtained in two stages:
        1) predictions and ground truth boxes are paired by the matcher
        2) every matched prediction / ground-truth pair is supervised
           (class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, not counting the special
                         no-object category
            matcher: module computing a matching between targets and proposals
            weight_dict: dict mapping loss names to their relative weight
            eos_coef: relative classification weight of the no-object category
            losses: names of the losses to apply, see get_loss for the
                    available ones
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

        # Every real class weighs 1, the trailing no-object class is
        # weighted by eos_coef.
        class_weights = torch.ones(self.num_classes + 1)
        class_weights[-1] = self.eos_coef
        self.register_buffer('empty_weight', class_weights)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Weighted cross entropy classification loss.

        Each target dict must provide "labels", a tensor of dim
        [nb_target_boxes]. Queries without a matched target are supervised
        with the no-object class index ``num_classes``.
        """
        assert 'pred_logits' in outputs
        logits = outputs['pred_logits']

        matched_pos = self._get_src_permutation_idx(indices)
        matched_labels = torch.cat([
            target['labels'][tgt_ids]
            for target, (_, tgt_ids) in zip(targets, indices)
        ])

        # Start from "no-object" everywhere, then fill in the matched slots.
        dense_labels = torch.full(
            logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=logits.device)
        dense_labels[matched_pos] = matched_labels

        # cross_entropy expects the class dimension at position 1.
        cross_entropy = F.cross_entropy(
            logits.transpose(1, 2), dense_labels, weight=self.empty_weight)
        result = {'loss_ce': cross_entropy * self.weight_dict['loss_ce']}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            top1 = accuracy(logits[matched_pos], matched_labels)[0]
            result['class_error'] = 100 - top1
        return result

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Absolute error in the number of predicted non-empty boxes.

        Meant for logging only: it runs without gradient tracking, so nothing
        is propagated back through it.
        """
        logits = outputs['pred_logits']
        gt_counts = torch.as_tensor(
            [len(target['labels']) for target in targets],
            device=logits.device)

        # A query counts as an object unless its top class is the last
        # ("no-object") one.
        no_object_index = logits.shape[-1] - 1
        is_object = logits.argmax(-1) != no_object_index
        pred_counts = is_object.sum(1)

        error = F.l1_loss(pred_counts.float(), gt_counts.float())
        return {'cardinality_error': error}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """L1 regression loss and GIoU loss of the matched boxes.

        Each target dict must provide "boxes", a tensor of dim
        [nb_target_boxes, 4] in (center_x, center_y, w, h) format normalized
        by the image size.
        """
        assert 'pred_boxes' in outputs
        matched_pos = self._get_src_permutation_idx(indices)
        pred = outputs['pred_boxes'][matched_pos]
        gt = torch.cat([
            target['boxes'][tgt_ids]
            for target, (_, tgt_ids) in zip(targets, indices)
        ],
                       dim=0)

        l1_per_coord = F.l1_loss(pred, gt, reduction='none')
        l1_total = l1_per_coord.sum()

        # Only the diagonal holds the GIoU of each prediction with its own
        # matched target.
        pairwise_giou = generalized_box_iou(
            box_cxcywh_to_xyxy(pred), box_cxcywh_to_xyxy(gt))
        giou_total = (1 - torch.diag(pairwise_giou)).sum()

        result = {}
        result['loss_bbox'] = (
            l1_total / num_boxes * self.weight_dict['loss_bbox'])
        result['loss_giou'] = (
            giou_total / num_boxes * self.weight_dict['loss_giou'])
        return result

    def _get_src_permutation_idx(self, indices):
        """Return (batch index, query index) of all matched predictions."""
        batch_parts = []
        query_parts = []
        for batch_id, (src, _) in enumerate(indices):
            batch_parts.append(torch.full_like(src, batch_id))
            query_parts.append(src)
        return torch.cat(batch_parts), torch.cat(query_parts)

    def _get_tgt_permutation_idx(self, indices):
        """Return (batch index, target index) of all matched targets."""
        batch_parts = []
        target_parts = []
        for batch_id, (_, tgt) in enumerate(indices):
            batch_parts.append(torch.full_like(tgt, batch_id))
            target_parts.append(tgt)
        return torch.cat(batch_parts), torch.cat(target_parts)

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss function registered under the name ``loss``.

        Available names are 'labels', 'cardinality' and 'boxes'.
        """
        handlers = dict(
            labels=self.loss_labels,
            cardinality=self.loss_cardinality,
            boxes=self.loss_boxes)
        assert loss in handlers, f'do you really want to compute {loss} loss?'
        handler = handlers[loss]
        return handler(outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """Compute all configured losses.

        Parameters:
            outputs: dict of tensors, see the output specification of the
                     model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses
                     applied, see each loss' doc
        """
        final_outputs = {}
        for key, value in outputs.items():
            if key != 'aux_outputs':
                final_outputs[key] = value

        # Match the outputs of the last layer with the targets.
        indices = self.matcher(final_outputs, targets)

        # Average number of target boxes across all nodes, used to normalize
        # the box losses; never smaller than 1.
        local_count = sum(len(target['labels']) for target in targets)
        ref_device = next(iter(outputs.values())).device
        num_boxes = torch.as_tensor([local_count],
                                    dtype=torch.float,
                                    device=ref_device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Losses of the final prediction.
        losses = {}
        for name in self.losses:
            losses.update(
                self.get_loss(name, outputs, targets, indices, num_boxes))

        # Auxiliary losses: the same procedure is repeated for the output of
        # every intermediate layer, with its own matching.
        for layer_id, aux_outputs in enumerate(
                outputs.get('aux_outputs', ())):
            aux_indices = self.matcher(aux_outputs, targets)
            for name in self.losses:
                if name == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                # Logging is enabled only for the last layer.
                extra = {'log': False} if name == 'labels' else {}
                layer_losses = self.get_loss(name, aux_outputs, targets,
                                             aux_indices, num_boxes, **extra)
                losses.update({
                    f'{key}_{layer_id}': value
                    for key, value in layer_losses.items()
                })

        return losses
|