# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from easycv.models.builder import HEADS, build_neck
from easycv.models.detection.utils import DetrPostProcess, box_xyxy_to_cxcywh
from easycv.models.loss import HungarianMatcher, SetCriterion
from easycv.models.utils import MLP


@HEADS.register_module()
class DETRHead(nn.Module):
    """Implements the DETR transformer head.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/abs/2005.12872>`_ for details.

    Args:
        num_classes (int): Number of categories excluding the background.
        embed_dims (int): Embedding dimension of the transformer.
        eos_coef (float): Classification weight of the no-object
            (background) class. Defaults to 0.1.
        transformer (dict): Config dict used to build the transformer neck.
        cost_dict (dict): Weights of the classification, bbox L1 and GIoU
            costs used by the Hungarian matcher.
        weight_dict (dict): Weights of the classification, bbox L1 and GIoU
            losses.
    """
    _version = 2

    def __init__(self,
                 num_classes,
                 embed_dims,
                 eos_coef=0.1,
                 transformer=None,
                 cost_dict={
                     'cost_class': 1,
                     'cost_bbox': 5,
                     'cost_giou': 2,
                 },
                 weight_dict={
                     'loss_ce': 1,
                     'loss_bbox': 5,
                     'loss_giou': 2
                 },
                 **kwargs):
        super(DETRHead, self).__init__()

        self.matcher = HungarianMatcher(cost_dict=cost_dict)
        self.criterion = SetCriterion(
            num_classes,
            matcher=self.matcher,
            weight_dict=weight_dict,
            eos_coef=eos_coef,
            losses=['labels', 'boxes'])
        self.postprocess = DetrPostProcess()

        self.transformer = build_neck(transformer)

        self.class_embed = nn.Linear(embed_dims, num_classes + 1)
        self.bbox_embed = MLP(embed_dims, embed_dims, 4, 3)
        self.num_classes = num_classes

    def init_weights(self):
        """Initialize weights of the detr head."""
        self.transformer.init_weights()

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            dict: Outputs of the final decoder layer plus auxiliary outputs:

                - pred_logits (Tensor): Classification logits (including the
                  background class) with shape
                  [bs, num_query, num_classes + 1].
                - pred_boxes (Tensor): Sigmoid regression outputs in
                  normalized coordinate format (cx, cy, w, h) with shape
                  [bs, num_query, 4].
                - aux_outputs (list[dict]): The same two entries for every
                  intermediate decoder layer.
        """
        feats = self.transformer(feats, img_metas)

        outputs_class = self.class_embed(feats)
        outputs_coord = self.bbox_embed(feats).sigmoid()
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1]
        }
        out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{
            'pred_logits': a,
            'pred_boxes': b
        } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    # Overwritten because img_metas is needed as an input for bbox_head.
    def forward_train(self, x, img_metas, gt_bboxes, gt_labels):
        """Forward function for training mode.

        Args:
            x (list[Tensor]): Features from backbone.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
                each with shape (num_gts, 4).
            gt_labels (list[Tensor]): Ground truth labels of the boxes in
                each image, each with shape (num_gts,).

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # prepare ground truth
        for i in range(len(img_metas)):
            img_h, img_w, _ = img_metas[i]['img_shape']
            # DETR regresses box positions relative to the image (in cxcywh
            # format), so the learning targets are normalized by the image
            # size and converted from the default x1y1x2y2 format to cxcywh.
            factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w,
                                              img_h]).unsqueeze(0)
            gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor

        targets = []
        for gt_label, gt_bbox in zip(gt_labels, gt_bboxes):
            targets.append({'labels': gt_label, 'boxes': gt_bbox})

        outputs = self.forward(x, img_metas)

        losses = self.criterion(outputs, targets)

        return losses

    def forward_test(self, x, img_metas):
        outputs = self.forward(x, img_metas)

        ori_shape_list = []
        for i in range(len(img_metas)):
            ori_h, ori_w, _ = img_metas[i]['ori_shape']
            ori_shape_list.append(torch.as_tensor([ori_h, ori_w]))
        orig_target_sizes = torch.stack(ori_shape_list, dim=0)

        results = self.postprocess(outputs, orig_target_sizes, img_metas)
        return results