# ------------------------------------------------------------------------
# DN-DETR
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch

from easycv.models.detection.utils import inverse_sigmoid


def prepare_for_cdn(dn_args, num_queries, num_classes, hidden_dim, label_enc):
    """
    A major difference of DINO from DN-DETR is that DINO processes the
    pattern embedding in its detector forward function and uses a learnable
    tgt embedding, so this function is changed slightly.
    :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
    :param num_queries: number of queries
    :param num_classes: number of classes
    :param hidden_dim: transformer hidden dim
    :param label_enc: encode labels in dn
    :return: input_query_label, input_query_bbox, attn_mask, dn_meta
    """
    targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
    # each dn group holds positive and negative queries, so double dn_number
    dn_number = dn_number * 2
    known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
    batch_size = len(known)
    known_num = [sum(k) for k in known]
    if int(max(known_num)) == 0:
        dn_number = 1
    else:
        if dn_number >= 100:
            dn_number = dn_number // (int(max(known_num) * 2))
        elif dn_number < 1:
            dn_number = 1
    if dn_number == 0:
        dn_number = 1
    unmask_bbox = unmask_label = torch.cat(known)
    labels = torch.cat([t['labels'] for t in targets])
    boxes = torch.cat([t['boxes'] for t in targets])
    batch_idx = torch.cat([
        torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)
    ])

    known_indice = torch.nonzero(unmask_label + unmask_bbox)
    known_indice = known_indice.view(-1)

    # repeat the ground truth for every (positive, negative) dn group
    known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
    known_labels = labels.repeat(2 * dn_number, 1).view(-1)
    known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
    known_bboxs = boxes.repeat(2 * dn_number, 1)
    known_labels_expaned = known_labels.clone()
    known_bbox_expand = known_bboxs.clone()

    if label_noise_ratio > 0:
        # flip a fraction of the GT labels to random classes
        p = torch.rand_like(known_labels_expaned.float())
        chosen_indice = torch.nonzero(p < (label_noise_ratio)).view(
            -1)  # half of bbox prob
        new_label = torch.randint_like(
            chosen_indice, 0, num_classes)  # randomly put a new one here
        known_labels_expaned.scatter_(0, chosen_indice, new_label)
    single_pad = int(max(known_num))

    pad_size = int(single_pad * 2 * dn_number)
    positive_idx = torch.tensor(range(
        len(boxes))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
    positive_idx += (torch.tensor(range(dn_number)) * len(boxes) *
                     2).long().cuda().unsqueeze(1)
    positive_idx = positive_idx.flatten()
    negative_idx = positive_idx + len(boxes)
    if box_noise_scale > 0:
        # convert (cx, cy, w, h) boxes to corner form (x1, y1, x2, y2)
        known_bbox_ = torch.zeros_like(known_bboxs)
        known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
        known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

        diff = torch.zeros_like(known_bboxs)
        diff[:, :2] = known_bboxs[:, 2:] / 2
        diff[:, 2:] = known_bboxs[:, 2:] / 2

        # positives get noise in [0, 1) x box size, negatives in [1, 2)
        rand_sign = torch.randint_like(
            known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
        rand_part = torch.rand_like(known_bboxs)
        rand_part[negative_idx] += 1.0
        rand_part *= rand_sign
        known_bbox_ = known_bbox_ + torch.mul(rand_part,
                                              diff).cuda() * box_noise_scale
        known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
        # back to (cx, cy, w, h)
        known_bbox_expand[:, :2] = (known_bbox_[:, :2] +
                                    known_bbox_[:, 2:]) / 2
        known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]

    m = known_labels_expaned.long().to('cuda')
    input_label_embed = label_enc(m)
    input_bbox_embed = inverse_sigmoid(known_bbox_expand)

    padding_label = torch.zeros(pad_size, hidden_dim).cuda()
    padding_bbox = torch.zeros(pad_size, 4).cuda()

    input_query_label = padding_label.repeat(batch_size, 1, 1)
    input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

    # scatter the dn queries into the padded per-image slots
    map_known_indice = torch.tensor([]).to('cuda')
    if len(known_num):
        map_known_indice = torch.cat(
            [torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
        map_known_indice = torch.cat([
            map_known_indice + single_pad * i for i in range(2 * dn_number)
        ]).long()
    if len(known_bid):
        input_query_label[(known_bid.long(),
                           map_known_indice)] = input_label_embed
        input_query_bbox[(known_bid.long(),
                          map_known_indice)] = input_bbox_embed

    tgt_size = pad_size + num_queries
    attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
    # match query cannot see the reconstruct
    attn_mask[pad_size:, :pad_size] = True
    # reconstruct query cannot see the match
    attn_mask[:pad_size, pad_size:] = True
    # reconstruct groups cannot see each other
    for i in range(dn_number):
        if i == 0:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                      single_pad * 2 * (i + 1):pad_size] = True
        if i == dn_number - 1:
            attn_mask[single_pad * 2 * i:single_pad * 2 *
                      (i + 1), :single_pad * i * 2] = True
        else:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                      single_pad * 2 * (i + 1):pad_size] = True
            attn_mask[single_pad * 2 * i:single_pad * 2 *
                      (i + 1), :single_pad * 2 * i] = True

    dn_meta = {
        'pad_size': pad_size,
        'num_dn_group': dn_number,
    }

    return input_query_label, input_query_bbox, attn_mask, dn_meta


def cdn_post_process(outputs_class,
                     outputs_coord,
                     dn_meta,
                     _set_aux_loss,
                     outputs_center=None,
                     outputs_iou=None,
                     reference=None):
    """
    Post-process the denoising (dn) part of the transformer outputs:
    split the dn predictions from the matching predictions and stash
    the dn part in dn_meta.
    """
    if dn_meta and dn_meta['pad_size'] > 0:
        output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
        output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
        outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
        outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
        output_known_center = None
        output_known_iou = None
        if outputs_center is not None:
            output_known_center = outputs_center[:, :, :dn_meta['pad_size'], :]
            outputs_center = outputs_center[:, :, dn_meta['pad_size']:, :]
        if outputs_iou is not None:
            output_known_iou = outputs_iou[:, :, :dn_meta['pad_size'], :]
            outputs_iou = outputs_iou[:, :, dn_meta['pad_size']:, :]
        known_reference = reference[:, :, :dn_meta['pad_size'], :]
        reference = reference[:, :, dn_meta['pad_size']:, :]
        out = {
            'pred_logits': output_known_class[-1],
            'pred_boxes': output_known_coord[-1],
            'pred_centers': output_known_center[-1]
            if output_known_center is not None else None,
            'pred_ious': output_known_iou[-1]
            if output_known_iou is not None else None,
            'refpts': known_reference[-1],
        }
        out['aux_outputs'] = _set_aux_loss(output_known_class,
                                           output_known_coord,
                                           output_known_center,
                                           output_known_iou, known_reference)
        dn_meta['output_known_lbs_bboxes'] = out
    return outputs_class, outputs_coord, outputs_center, outputs_iou, reference
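

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: a CUDA device is available (both functions call .cuda()
# internally); `label_enc` is a stand-in nn.Embedding and `_set_aux_loss`
# a hypothetical DETR-style aux-loss packer, both of which a real detector
# supplies itself. Boxes are normalized (cx, cy, w, h) in [0, 1].
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    num_classes, hidden_dim, num_queries = 80, 256, 900
    label_enc = nn.Embedding(num_classes + 1, hidden_dim).cuda()
    targets = [{
        'labels': torch.tensor([3, 17]).cuda(),
        'boxes':
        torch.tensor([[0.5, 0.5, 0.2, 0.3], [0.3, 0.6, 0.1, 0.1]]).cuda(),
    }]
    # dn_number=100 denoising queries, 50% label noise, box noise scale 1.0
    dn_args = (targets, 100, 0.5, 1.0)
    input_query_label, input_query_bbox, attn_mask, dn_meta = prepare_for_cdn(
        dn_args, num_queries, num_classes, hidden_dim, label_enc)
    # input_query_label: (batch, pad_size, hidden_dim)
    # attn_mask: (pad_size + num_queries, pad_size + num_queries)
    print(input_query_label.shape, attn_mask.shape, dn_meta)

    # Split dn predictions from matching predictions after the decoder.
    # Shapes are (num_decoder_layers, batch, pad_size + num_queries, dim).
    num_layers = 6
    total_q = dn_meta['pad_size'] + num_queries
    outputs_class = torch.randn(num_layers, 1, total_q, num_classes).cuda()
    outputs_coord = torch.rand(num_layers, 1, total_q, 4).cuda()
    reference = torch.rand(num_layers, 1, total_q, 4).cuda()

    def _set_aux_loss(cls, coord, center, iou, ref):
        # hypothetical stand-in mirroring DETR-style aux-loss packing
        return [{
            'pred_logits': c,
            'pred_boxes': b
        } for c, b in zip(cls[:-1], coord[:-1])]

    outputs_class, outputs_coord, _, _, reference = cdn_post_process(
        outputs_class, outputs_coord, dn_meta, _set_aux_loss,
        reference=reference)
    # dn part stripped: only the num_queries matching queries remain
    print(outputs_class.shape, reference.shape)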