# Copyright (c) 2022 IDEA. All Rights Reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from easycv.framework.errors import NotImplementedError
from easycv.models.builder import HEADS, build_neck
from easycv.models.detection.utils import (DetrPostProcess, box_xyxy_to_cxcywh,
inverse_sigmoid)
from easycv.models.loss import CDNCriterion, HungarianMatcher, SetCriterion
from easycv.models.utils import MLP
from easycv.utils.dist_utils import get_dist_info, is_dist_available
from ..dab_detr.dab_detr_transformer import PositionEmbeddingSineHW
from .cdn_components import cdn_post_process, prepare_for_cdn


@HEADS.register_module()
class DINOHead(nn.Module):
""" Initializes the DINO Head.
See `paper: DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection
<https://arxiv.org/abs/2203.03605>`_ for details.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_classes: number of object classes
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
fix_refpoints_hw: -1(default): learn w and h for each box seperately
>0 : given fixed number
-2 : learn a shared w and h
"""
def __init__(
self,
num_classes,
embed_dims,
in_channels=[512, 1024, 2048],
query_dim=4,
num_queries=300,
num_select=300,
random_refpoints_xy=False,
num_patterns=0,
dn_components=None,
transformer=None,
fix_refpoints_hw=-1,
num_feature_levels=1,
# two stage
two_stage_type='standard', # ['no', 'standard']
two_stage_add_query_num=0,
dec_pred_class_embed_share=True,
dec_pred_bbox_embed_share=True,
two_stage_class_embed_share=True,
two_stage_bbox_embed_share=True,
use_centerness=False,
use_iouaware=False,
losses_list=['labels', 'boxes'],
decoder_sa_type='sa',
temperatureH=20,
temperatureW=20,
cost_dict={
'cost_class': 1,
'cost_bbox': 5,
'cost_giou': 2,
},
weight_dict={
'loss_ce': 1,
'loss_bbox': 5,
'loss_giou': 2
},
**kwargs):
super(DINOHead, self).__init__()
self.matcher = HungarianMatcher(
cost_dict=cost_dict, cost_class_type='focal_loss_cost')
self.criterion = SetCriterion(
num_classes,
matcher=self.matcher,
weight_dict=weight_dict,
losses=losses_list,
loss_class_type='focal_loss')
if dn_components is not None:
self.dn_criterion = CDNCriterion(
num_classes,
matcher=self.matcher,
weight_dict=weight_dict,
losses=losses_list,
loss_class_type='focal_loss')
self.postprocess = DetrPostProcess(
num_select=num_select,
use_centerness=use_centerness,
use_iouaware=use_iouaware)
self.transformer = build_neck(transformer)
self.positional_encoding = PositionEmbeddingSineHW(
embed_dims // 2,
temperatureH=temperatureH,
temperatureW=temperatureW,
normalize=True)
self.num_classes = num_classes
self.num_queries = num_queries
self.embed_dims = embed_dims
self.query_dim = query_dim
self.dn_components = dn_components
self.random_refpoints_xy = random_refpoints_xy
self.fix_refpoints_hw = fix_refpoints_hw
# for dn training
self.dn_number = self.dn_components['dn_number']
self.dn_box_noise_scale = self.dn_components['dn_box_noise_scale']
self.dn_label_noise_ratio = self.dn_components['dn_label_noise_ratio']
self.dn_labelbook_size = self.dn_components['dn_labelbook_size']
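        # embeds (possibly noised) ground-truth class labels into content
        # queries for CDN denoising training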
self.label_enc = nn.Embedding(self.dn_labelbook_size + 1, embed_dims)
# prepare input projection layers
self.num_feature_levels = num_feature_levels
if num_feature_levels > 1:
num_backbone_outs = len(in_channels)
input_proj_list = []
for i in range(num_backbone_outs):
in_channels_i = in_channels[i]
input_proj_list.append(
nn.Sequential(
nn.Conv2d(in_channels_i, embed_dims, kernel_size=1),
nn.GroupNorm(32, embed_dims),
))
for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append(
nn.Sequential(
nn.Conv2d(
in_channels_i,
embed_dims,
kernel_size=3,
stride=2,
padding=1),
nn.GroupNorm(32, embed_dims),
))
in_channels_i = embed_dims
self.input_proj = nn.ModuleList(input_proj_list)
else:
            assert two_stage_type == 'no', 'two_stage_type should be "no" when num_feature_levels == 1'
self.input_proj = nn.ModuleList([
nn.Sequential(
nn.Conv2d(in_channels[-1], embed_dims, kernel_size=1),
nn.GroupNorm(32, embed_dims),
)
])
# prepare pred layers
self.dec_pred_class_embed_share = dec_pred_class_embed_share
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
# prepare class & box embed
_class_embed = nn.Linear(embed_dims, num_classes)
_bbox_embed = MLP(embed_dims, embed_dims, 4, 3)
# init the two embed layers
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
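        # focal-loss prior initialization (as in RetinaNet): bias the class
        # logits so that the initial foreground probability is ~prior_prob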
_class_embed.bias.data = torch.ones(self.num_classes) * bias_value
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
# fcos centerness & iou-aware & tokenlabel
self.use_centerness = use_centerness
self.use_iouaware = use_iouaware
if self.use_centerness:
_center_embed = MLP(embed_dims, embed_dims, 1, 3)
if self.use_iouaware:
_iou_embed = MLP(embed_dims, embed_dims, 1, 3)
if dec_pred_bbox_embed_share:
box_embed_layerlist = [
_bbox_embed for i in range(transformer.num_decoder_layers)
]
if self.use_centerness:
center_embed_layerlist = [
_center_embed
for i in range(transformer.num_decoder_layers)
]
if self.use_iouaware:
iou_embed_layerlist = [
_iou_embed for i in range(transformer.num_decoder_layers)
]
else:
box_embed_layerlist = [
copy.deepcopy(_bbox_embed)
for i in range(transformer.num_decoder_layers)
]
if self.use_centerness:
center_embed_layerlist = [
copy.deepcopy(_center_embed)
for i in range(transformer.num_decoder_layers)
]
if self.use_iouaware:
iou_embed_layerlist = [
copy.deepcopy(_iou_embed)
for i in range(transformer.num_decoder_layers)
]
if dec_pred_class_embed_share:
class_embed_layerlist = [
_class_embed for i in range(transformer.num_decoder_layers)
]
else:
class_embed_layerlist = [
copy.deepcopy(_class_embed)
for i in range(transformer.num_decoder_layers)
]
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed
self.transformer.decoder.class_embed = self.class_embed
if self.use_centerness:
self.center_embed = nn.ModuleList(center_embed_layerlist)
self.transformer.decoder.center_embed = self.center_embed
if self.use_iouaware:
self.iou_embed = nn.ModuleList(iou_embed_layerlist)
self.transformer.decoder.iou_embed = self.iou_embed
# two stage
self.two_stage_type = two_stage_type
self.two_stage_add_query_num = two_stage_add_query_num
assert two_stage_type in [
'no', 'standard'
], 'unknown param {} of two_stage_type'.format(two_stage_type)
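        # with two-stage enabled, the encoder proposals get their own class /
        # box (and optional centerness / IoU-aware) heads, either shared with
        # or copied from the decoder heads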
if two_stage_type != 'no':
if two_stage_bbox_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_bbox_embed = _bbox_embed
if self.use_centerness:
self.transformer.enc_out_center_embed = _center_embed
if self.use_iouaware:
self.transformer.enc_out_iou_embed = _iou_embed
else:
self.transformer.enc_out_bbox_embed = copy.deepcopy(
_bbox_embed)
if self.use_centerness:
self.transformer.enc_out_center_embed = copy.deepcopy(
_center_embed)
if self.use_iouaware:
self.transformer.enc_out_iou_embed = copy.deepcopy(
_iou_embed)
if two_stage_class_embed_share:
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
self.transformer.enc_out_class_embed = _class_embed
else:
self.transformer.enc_out_class_embed = copy.deepcopy(
_class_embed)
self.refpoint_embed = None
if self.two_stage_add_query_num > 0:
self.init_ref_points(two_stage_add_query_num)
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
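        # 'sa' keeps standard decoder self-attention; 'ca_label' additionally
        # provides a learned label embedding to every decoder layer (see the
        # reference DINO implementation for the 'ca_content' variant)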
if decoder_sa_type == 'ca_label':
self.label_embedding = nn.Embedding(num_classes, embed_dims)
for layer in self.transformer.decoder.layers:
layer.label_embedding = self.label_embedding
else:
for layer in self.transformer.decoder.layers:
layer.label_embedding = None
self.label_embedding = None

    def init_weights(self):
# init input_proj
for proj in self.input_proj:
nn.init.xavier_uniform_(proj[0].weight, gain=1)
nn.init.constant_(proj[0].bias, 0)

    def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
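        # optionally freeze random xy positions; reference points are stored in
        # inverse-sigmoid (logit) space so the decoder can apply sigmoid later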
if self.random_refpoints_xy:
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False
if self.fix_refpoints_hw > 0:
print('fix_refpoints_hw: {}'.format(self.fix_refpoints_hw))
assert self.random_refpoints_xy
self.refpoint_embed.weight.data[:, 2:] = self.fix_refpoints_hw
self.refpoint_embed.weight.data[:, 2:] = inverse_sigmoid(
self.refpoint_embed.weight.data[:, 2:])
self.refpoint_embed.weight.data[:, 2:].requires_grad = False
elif int(self.fix_refpoints_hw) == -1:
pass
elif int(self.fix_refpoints_hw) == -2:
print('learn a shared h and w')
assert self.random_refpoints_xy
self.refpoint_embed = nn.Embedding(use_num_queries, 2)
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False
self.hw_embed = nn.Embedding(1, 1)
else:
raise NotImplementedError('Unknown fix_refpoints_hw {}'.format(
self.fix_refpoints_hw))

    def prepare(self, features, targets=None, mode='train'):
if self.dn_number > 0 and targets is not None:
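            # build noised label/box queries and the self-attention mask for
            # contrastive denoising (CDN) training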
            input_query_label, input_query_bbox, attn_mask, dn_meta = \
                prepare_for_cdn(
                    dn_args=(targets, self.dn_number,
                             self.dn_label_noise_ratio, self.dn_box_noise_scale),
                    num_queries=self.num_queries, num_classes=self.num_classes,
                    hidden_dim=self.embed_dims, label_enc=self.label_enc)
else:
assert targets is None
input_query_bbox = input_query_label = attn_mask = dn_meta = None
return input_query_bbox, input_query_label, attn_mask, dn_meta

    def forward(self,
feats,
img_metas,
query_embed=None,
tgt=None,
attn_mask=None,
dn_meta=None):
"""Forward function.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
img_metas (list[dict]): List of image information.
Returns:
tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
- all_cls_scores_list (list[Tensor]): Classification scores \
for each scale level. Each is a 4D-tensor with shape \
[nb_dec, bs, num_query, cls_out_channels]. Note \
`cls_out_channels` should includes background.
- all_bbox_preds_list (list[Tensor]): Sigmoid regression \
outputs for each scale level. Each is a 4D-tensor with \
normalized coordinate format (cx, cy, w, h) and shape \
[nb_dec, bs, num_query, 4].
"""
        # construct binary masks which are used by the transformer.
        # NOTE: following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.
bs = feats[0].size(0)
input_img_h, input_img_w = img_metas[0]['batch_input_shape']
img_masks = feats[0].new_ones((bs, input_img_h, input_img_w))
for img_id in range(bs):
img_h, img_w, _ = img_metas[img_id]['img_shape']
img_masks[img_id, :img_h, :img_w] = 0
srcs = []
masks = []
poss = []
for l, src in enumerate(feats):
mask = F.interpolate(
img_masks[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
# position encoding
pos_l = self.positional_encoding(mask) # [bs, embed_dim, h, w]
srcs.append(self.input_proj[l](src))
masks.append(mask)
poss.append(pos_l)
assert mask is not None
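        # if the transformer expects more feature levels than the backbone
        # provides, create the extra levels by strided projections of the
        # last feature map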
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](feats[-1])
else:
src = self.input_proj[l](srcs[-1])
mask = F.interpolate(
img_masks[None].float(),
size=src.shape[-2:]).to(torch.bool)[0]
# position encoding
pos_l = self.positional_encoding(mask) # [bs, embed_dim, h, w]
srcs.append(src)
masks.append(mask)
poss.append(pos_l)
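        # hs: per-decoder-layer outputs, reference: per-layer reference points,
        # hs_enc / ref_enc: encoder (two-stage) outputs, init_box_proposal:
        # initial box proposals from the encoder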
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
srcs, masks, query_embed, poss, tgt, attn_mask)
        # in case the number of objects is 0: add a zero-valued term so that
        # label_enc still participates in the computation graph
        hs[0] += self.label_enc.weight[0, 0] * 0.0
# deformable-detr-like anchor update
# reference_before_sigmoid = inverse_sigmoid(reference[:-1]) # n_dec, bs, nq, 4
outputs_coord_list = []
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
zip(reference[:-1], self.bbox_embed, hs)):
layer_delta_unsig = layer_bbox_embed(layer_hs)
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(
layer_ref_sig)
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
outputs_coord_list.append(layer_outputs_unsig)
outputs_coord_list = torch.stack(outputs_coord_list)
# outputs_class = self.class_embed(hs)
outputs_class = torch.stack([
layer_cls_embed(layer_hs)
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
])
outputs_center_list = None
if self.use_centerness:
outputs_center_list = torch.stack([
layer_center_embed(layer_hs)
for layer_center_embed, layer_hs in zip(self.center_embed, hs)
])
outputs_iou_list = None
if self.use_iouaware:
outputs_iou_list = torch.stack([
layer_iou_embed(layer_hs)
for layer_iou_embed, layer_hs in zip(self.iou_embed, hs)
])
reference = torch.stack(reference)[:-1][..., :2]
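        # split the denoising part of the predictions from the matching part
        # before assembling the outputs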
if self.dn_number > 0 and dn_meta is not None:
outputs_class, outputs_coord_list, outputs_center_list, outputs_iou_list, reference = cdn_post_process(
outputs_class, outputs_coord_list, dn_meta, self._set_aux_loss,
outputs_center_list, outputs_iou_list, reference)
out = {
'pred_logits':
outputs_class[-1],
'pred_boxes':
outputs_coord_list[-1],
'pred_centers':
outputs_center_list[-1]
if outputs_center_list is not None else None,
'pred_ious':
outputs_iou_list[-1] if outputs_iou_list is not None else None,
'refpts':
reference[-1],
}
out['aux_outputs'] = self._set_aux_loss(outputs_class,
outputs_coord_list,
outputs_center_list,
outputs_iou_list, reference)
# for encoder output
if hs_enc is not None:
# prepare intermediate outputs
interm_coord = ref_enc[-1]
interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
if self.use_centerness:
interm_center = self.transformer.enc_out_center_embed(
hs_enc[-1])
if self.use_iouaware:
interm_iou = self.transformer.enc_out_iou_embed(hs_enc[-1])
out['interm_outputs'] = {
'pred_logits': interm_class,
'pred_boxes': interm_coord,
'pred_centers': interm_center if self.use_centerness else None,
'pred_ious': interm_iou if self.use_iouaware else None,
'refpts': init_box_proposal[..., :2],
}
out['dn_meta'] = dn_meta
return out

    @torch.jit.unused
def _set_aux_loss(self,
outputs_class,
outputs_coord,
outputs_center=None,
outputs_iou=None,
reference=None):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{
'pred_logits':
a,
'pred_boxes':
b,
'pred_centers':
outputs_center[i] if outputs_center is not None else None,
'pred_ious':
outputs_iou[i] if outputs_iou is not None else None,
'refpts':
reference[i],
} for i, (a,
b) in enumerate(zip(outputs_class[:-1], outputs_coord[:-1]))]

    # overridden because img_metas are needed as inputs for the bbox head
    def forward_train(self, x, img_metas, gt_bboxes, gt_labels):
"""Forward function for training mode.
Args:
x (list[Tensor]): Features from backbone.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
# prepare ground truth
for i in range(len(img_metas)):
img_h, img_w, _ = img_metas[i]['img_shape']
            # DETR regresses the relative positions of boxes (cxcywh) in the
            # image, so the learning targets are normalized by the image size
            # and the box format is converted from the default x1y1x2y2 to cxcywh.
factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w,
img_h]).unsqueeze(0)
gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor
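        # build DETR-style targets: one dict per image with class labels and
        # normalized cxcywh boxes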
targets = []
for gt_label, gt_bbox in zip(gt_labels, gt_bboxes):
targets.append({'labels': gt_label, 'boxes': gt_bbox})
query_embed, tgt, attn_mask, dn_meta = self.prepare(
x, targets=targets, mode='train')
outputs = self.forward(
x,
img_metas,
query_embed=query_embed,
tgt=tgt,
attn_mask=attn_mask,
dn_meta=dn_meta)
        # Use the same num_boxes for the set criterion and the dn criterion to
        # avoid inconsistency.
        # Compute the average number of target boxes across all nodes, for
        # normalization purposes.
num_boxes = sum(len(t['labels']) for t in targets)
num_boxes = torch.as_tensor([num_boxes],
dtype=torch.float,
device=next(iter(outputs.values())).device)
if is_dist_available():
torch.distributed.all_reduce(num_boxes)
_, world_size = get_dist_info()
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
losses = self.criterion(outputs, targets, num_boxes=num_boxes)
losses.update(
self.dn_criterion(outputs, targets, len(outputs['aux_outputs']),
num_boxes))
return losses

    def forward_test(self, x, img_metas):
query_embed, tgt, attn_mask, dn_meta = self.prepare(x, mode='test')
outputs = self.forward(
x,
img_metas,
query_embed=query_embed,
tgt=tgt,
attn_mask=attn_mask,
dn_meta=dn_meta)
ori_shape_list = []
for i in range(len(img_metas)):
ori_h, ori_w, _ = img_metas[i]['ori_shape']
ori_shape_list.append(torch.as_tensor([ori_h, ori_w]))
orig_target_sizes = torch.stack(ori_shape_list, dim=0)
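        # select the top predictions and rescale boxes back to the original
        # image sizes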
results = self.postprocess(outputs, orig_target_sizes, img_metas)
return results