tuofeilun 654554cf65
support obj365 (#242)
Support Objects365 pretraining and add the DINO++ model, which achieves 63.4 mAP at a model scale of 200M (the best accuracy among models of the same scale).
2022-12-02 14:33:01 +08:00

188 lines
7.9 KiB
Python

# ------------------------------------------------------------------------
# DN-DETR
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import torch
from easycv.models.detection.utils import inverse_sigmoid
def prepare_for_cdn(dn_args, num_queries, num_classes, hidden_dim, label_enc):
    """Build the contrastive denoising (CDN) queries and their attention mask.

    A major difference of DINO from DN-DETR is that DINO processes the pattern
    embedding in its detector forward function and uses a learnable tgt
    embedding, so this function differs a little from the DN-DETR version.

    ``dn_number`` is first doubled; the (adjusted) result is the number of dn
    groups. Every ground truth of the batch is copied twice per group, and
    the copies alternate positive / negative: copy ``2g`` is the positive and
    copy ``2g + 1`` the negative of group ``g``. Labels of all copies may be
    replaced by random classes. Boxes are jittered, and negatives are pushed
    further away than positives.

    Boxes are treated as ``(cx, cy, w, h)``, and the noised corners are
    clamped to ``[0, 1]`` (presumably normalized image coordinates -- TODO
    confirm). All new tensors are created on the device of the target boxes,
    so targets are expected to live on the same device as ``label_enc``.

    :param dn_args: tuple ``(targets, dn_number, label_noise_ratio,
        box_noise_scale)``. ``targets`` is a non-empty list with one dict per
        image, holding a 1-D ``'labels'`` tensor and a ``'boxes'`` tensor with
        4 columns. After doubling, a ``dn_number`` >= 100 is treated as a
        query budget. It is divided by twice the largest per-image
        ground-truth count, so that ``pad_size`` stays within it. At least
        one group is always used.
    :param num_queries: number of matching queries
    :param num_classes: number of classes; noisy labels are drawn from
        ``[0, num_classes)``
    :param hidden_dim: transformer hidden dim
    :param label_enc: callable that encodes label indices into ``hidden_dim``
        vectors (e.g. an embedding layer)
    :return: ``(input_query_label, input_query_bbox, attn_mask, dn_meta)``:
        ``input_query_label`` is ``(batch, pad_size, hidden_dim)``;
        ``input_query_bbox`` is ``(batch, pad_size, 4)`` in inverse-sigmoid
        space; ``attn_mask`` is a square bool mask of side
        ``pad_size + num_queries`` whose True entries block attention;
        ``dn_meta`` is ``{'pad_size': ..., 'num_dn_group': ...}``.
    :raises ValueError: if ``targets`` is empty.
    """
    targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
    # positive and negative dn queries
    dn_number = dn_number * 2
    batch_size = len(targets)
    # Ground-truth count per image as plain ints. Summing a CUDA tensor
    # element by element would launch one op per label and force syncs.
    known_num = [len(t['labels']) for t in targets]
    single_pad = int(max(known_num))  # largest per-image GT count
    if single_pad == 0:
        dn_number = 1
    else:
        if dn_number >= 100:
            # Query budget: keep 2 * single_pad * dn_number <= budget.
            dn_number = dn_number // (single_pad * 2)
        # Always keep at least one group.
        dn_number = max(dn_number, 1)
    num_copies = 2 * dn_number
    pad_size = int(single_pad * num_copies)

    labels = torch.cat([t['labels'] for t in targets])
    boxes = torch.cat([t['boxes'] for t in targets])
    batch_idx = torch.cat([
        torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)
    ])
    device = boxes.device
    num_gt = len(boxes)

    # Flattened copies are copy-major: row c * num_gt + j is copy c of the
    # j-th ground truth of the batch.
    known_labels = labels.repeat(num_copies, 1).view(-1)
    known_bid = batch_idx.repeat(num_copies, 1).view(-1)
    known_bboxs = boxes.repeat(num_copies, 1)
    known_bbox_expand = known_bboxs.clone()

    if label_noise_ratio > 0:
        # NOTE(review): each label is resampled with probability
        # ``label_noise_ratio`` itself (not half of it) -- confirm intent.
        p = torch.rand_like(known_labels.float())
        chosen_indice = torch.nonzero(p < label_noise_ratio).view(-1)
        new_label = torch.randint_like(chosen_indice, 0, num_classes)
        known_labels.scatter_(0, chosen_indice, new_label)

    if box_noise_scale > 0:
        # Row indices of the negative copies (copy 2g + 1 of each group g).
        group_start = torch.arange(dn_number, device=device) * (2 * num_gt)
        negative_idx = (group_start.unsqueeze(1) + num_gt +
                        torch.arange(num_gt, device=device).unsqueeze(0))
        negative_idx = negative_idx.flatten()

        # (cx, cy, w, h) -> (x0, y0, x1, y1)
        half_wh = known_bboxs[:, 2:] / 2
        known_bbox_ = torch.zeros_like(known_bboxs)
        known_bbox_[:, :2] = known_bboxs[:, :2] - half_wh
        known_bbox_[:, 2:] = known_bboxs[:, :2] + half_wh
        # Corner shifts are measured in units of half the box width/height.
        diff = torch.zeros_like(known_bboxs)
        diff[:, :2] = half_wh
        diff[:, 2:] = half_wh
        rand_sign = torch.randint_like(
            known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
        # Positives get a magnitude in [0, 1), negatives one in [1, 2).
        rand_part = torch.rand_like(known_bboxs)
        rand_part[negative_idx] += 1.0
        rand_part *= rand_sign
        known_bbox_ = known_bbox_ + rand_part * diff * box_noise_scale
        known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
        # (x0, y0, x1, y1) -> (cx, cy, w, h)
        known_bbox_expand[:, :2] = (known_bbox_[:, :2] +
                                    known_bbox_[:, 2:]) / 2
        known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]

    input_label_embed = label_enc(known_labels.long().to(device))
    input_bbox_embed = inverse_sigmoid(known_bbox_expand)

    input_query_label = torch.zeros(
        batch_size, pad_size, hidden_dim, device=device)
    input_query_bbox = torch.zeros(batch_size, pad_size, 4, device=device)

    # Copy c of the j-th GT of an image goes to slot c * single_pad + j of
    # that image's dn block. Images with fewer GTs leave zero-padded gaps.
    per_image = torch.cat(
        [torch.arange(num, device=device) for num in known_num])
    map_known_indice = torch.cat(
        [per_image + single_pad * c for c in range(num_copies)])
    if known_bid.numel():
        input_query_label[known_bid, map_known_indice] = input_label_embed
        input_query_bbox[known_bid, map_known_indice] = input_bbox_embed

    tgt_size = pad_size + num_queries
    # True = attention blocked.
    attn_mask = torch.zeros(
        tgt_size, tgt_size, dtype=torch.bool, device=device)
    # match query cannot see the reconstruct
    attn_mask[pad_size:, :pad_size] = True
    # reconstruct query cannot see the match
    attn_mask[:pad_size, pad_size:] = True
    # reconstruct groups cannot see each other
    group_size = 2 * single_pad
    for g in range(dn_number):
        start, end = group_size * g, group_size * (g + 1)
        attn_mask[start:end, :start] = True
        attn_mask[start:end, end:pad_size] = True

    dn_meta = {
        'pad_size': pad_size,
        'num_dn_group': dn_number,
    }
    return input_query_label, input_query_bbox, attn_mask, dn_meta
def cdn_post_process(outputs_class,
                     outputs_coord,
                     dn_meta,
                     _set_aux_loss,
                     outputs_center=None,
                     outputs_iou=None,
                     reference=None):
    """Post-process the transformer outputs by splitting off the dn part.

    The first ``dn_meta['pad_size']`` entries along dim 2 of every output are
    the denoising predictions (presumably from the queries placed first by
    ``prepare_for_cdn``). They are removed from the returned outputs and
    stored in ``dn_meta['output_known_lbs_bboxes']``. That dict holds the last
    entry along dim 0 of each dn part (presumably the final decoder layer) and
    ``'aux_outputs'``, the result of ``_set_aux_loss`` on the full dn parts.
    Optional outputs that are ``None`` stay ``None`` everywhere.

    :param outputs_class: class logits, sliced as ``[:, :, :pad_size]``
    :param outputs_coord: box predictions, sliced the same way
    :param dn_meta: dict from ``prepare_for_cdn``, or a falsy value; mutated
        in place when it has a positive ``'pad_size'``
    :param _set_aux_loss: callable ``(class, coord, center, iou, reference)``
        whose result is stored under ``'aux_outputs'``
    :param outputs_center: optional center predictions
    :param outputs_iou: optional IoU predictions
    :param reference: optional reference points
    :return: ``(outputs_class, outputs_coord, outputs_center, outputs_iou,
        reference)`` with the dn part removed (unchanged when there is none)
    """
    if not dn_meta or dn_meta['pad_size'] <= 0:
        return (outputs_class, outputs_coord, outputs_center, outputs_iou,
                reference)

    pad_size = dn_meta['pad_size']

    def _split(tensor):
        # (dn part, matching part) along dim 2; (None, None) for a None input.
        if tensor is None:
            return None, None
        return tensor[:, :, :pad_size], tensor[:, :, pad_size:]

    def _last(tensor):
        # Last entry along dim 0, or None for a None input.
        return None if tensor is None else tensor[-1]

    output_known_class, outputs_class = _split(outputs_class)
    output_known_coord, outputs_coord = _split(outputs_coord)
    output_known_center, outputs_center = _split(outputs_center)
    output_known_iou, outputs_iou = _split(outputs_iou)
    # Guarded like center/iou: ``reference`` defaults to None.
    known_reference, reference = _split(reference)

    out = {
        'pred_logits': output_known_class[-1],
        'pred_boxes': output_known_coord[-1],
        'pred_centers': _last(output_known_center),
        'pred_ious': _last(output_known_iou),
        'refpts': _last(known_reference),
    }
    out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord,
                                       output_known_center, output_known_iou,
                                       known_reference)
    dn_meta['output_known_lbs_bboxes'] = out
    return outputs_class, outputs_coord, outputs_center, outputs_iou, reference