mmyolo/projects/assigner_visualization/dense_heads/rtmdet_head_assigner.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import InstanceList
from torch import Tensor

from mmyolo.models import RTMDetHead
from mmyolo.models.utils import gt_instances_preprocess
from mmyolo.registry import MODELS


@MODELS.register_module()
class RTMHeadAssigner(RTMDetHead):

    def assign_by_gt_and_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            inputs_hw: Union[Tensor, tuple] = (640, 640)
    ) -> list:
"""Calculate the assigning results based on the gt and features
extracted by the detection head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Decoded box for each scale
level with shape (N, num_anchors * 4, H, W) in
[tl_x, tl_y, br_x, br_y] format.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
Returns:
dict[str, Tensor]: A dictionary of assigning results.
"""
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        prior_offset = self.prior_generator.offset

        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
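        # Note: gt_instances_preprocess pads every image's ground truths to a
        # common count, so gt_info has shape (num_imgs, max_gt, 5) holding
        # [label, x1, y1, x2, y2]; pad_bbox_flag marks the real boxes, the
        # only rows whose coordinates sum to a positive value.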

        device = cls_scores[0].device

        # Regenerate the cached flattened priors whenever the feature-map
        # sizes change (first call or a new input resolution).
        if featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                featmap_sizes, device=device, with_stride=True)
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)
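
        # With ``with_stride=True`` every prior row is (x, y, stride_w,
        # stride_h), so column -1 below is the per-prior stride and
        # columns :2 are the prior center coordinates.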
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1).contiguous()
        flatten_bboxes = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], 1)
        flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
                                                                    None]
        flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
                                       flatten_bboxes)
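
        # The two lines above decode the raw (l, t, r, b) predictions: the
        # distances are in stride units, so they are scaled by the per-prior
        # stride and then expanded around the prior center. E.g. a prior at
        # (8, 8) with stride 8 and prediction (1, 1, 1, 1) decodes to the
        # box (0, 0, 16, 16).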

        assigned_result = self.assigner(flatten_bboxes.detach(),
                                        flatten_cls_scores.detach(),
                                        self.flatten_priors_train, gt_labels,
                                        gt_bboxes, pad_bbox_flag)
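        # ``self.assigner`` is whatever the head was configured with; RTMDet
        # typically uses mmyolo's BatchDynamicSoftLabelAssigner, whose result
        # dict carries per-prior ``assigned_labels`` and ``assigned_bboxes``.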

        labels = assigned_result['assigned_labels'].reshape(-1)
        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        targets = bbox_targets[pos_inds]

        gt_bboxes = gt_bboxes.squeeze(0)
        matched_gt_inds = torch.tensor(
            [((t == gt_bboxes).sum(dim=1) == t.shape[0]).nonzero()[0]
             for t in targets],
            device=device)
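        # Each positive target is mapped back to its ground-truth index by an
        # exact coordinate match against the gt boxes. The squeeze(0) above
        # assumes a batch size of 1, which is how the visualization script
        # drives this head.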

        level_inds = torch.zeros_like(labels)
        img_inds = torch.zeros_like(labels)
        level_nums = [0] + [f[0] * f[1] for f in featmap_sizes]
        for i in range(len(level_nums) - 1):
            level_nums[i + 1] = level_nums[i] + level_nums[i + 1]
            level_inds[level_nums[i]:level_nums[i + 1]] = i
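
        # ``level_nums`` is now a prefix sum of per-level prior counts. For a
        # 640x640 input with strides (8, 16, 32) the feature maps measure
        # 80x80, 40x40 and 20x20, giving level_nums = [0, 6400, 8000, 8400];
        # a flattened prior index belongs to the level whose range covers it.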
        level_inds_pos = level_inds[pos_inds]

        img_inds = img_inds[pos_inds]
        labels = labels[pos_inds]

        inputs_hw = batch_img_metas[0]['batch_input_shape']
        assign_results = []
        for i in range(self.num_levels):
            retained_inds = level_inds_pos == i
            if not retained_inds.any():
                # No positive prior on this level: emit empty index tensors.
                empty = torch.zeros([0], dtype=torch.int64, device=device)
                assign_results_prior = {
                    'stride': self.featmap_strides[i],
                    'grid_x_inds': empty,
                    'grid_y_inds': empty,
                    'img_inds': empty,
                    'class_inds': empty,
                    'retained_gt_inds': empty,
                    'prior_ind': 0,
                    'offset': prior_offset
                }
            else:
                w = inputs_hw[1] // self.featmap_strides[i]
                retained_pos_inds = pos_inds[retained_inds] - level_nums[i]
                grid_y_inds = retained_pos_inds // w
                grid_x_inds = retained_pos_inds % w
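                # A level-local flattened index splits into a grid row and
                # column: with w = 80, index 165 lands on cell (y, x) = (2, 5).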
                assign_results_prior = {
                    'stride': self.featmap_strides[i],
                    'grid_x_inds': grid_x_inds,
                    'grid_y_inds': grid_y_inds,
                    'img_inds': img_inds[retained_inds],
                    'class_inds': labels[retained_inds],
                    'retained_gt_inds': matched_gt_inds[retained_inds],
                    'prior_ind': 0,
                    'offset': prior_offset
                }
            assign_results.append([assign_results_prior])
        return assign_results

    def assign(self, batch_data_samples: Union[list, dict],
               inputs_hw: Union[tuple, torch.Size]) -> list:
        """Calculate assigning results. This function is provided to the
        `assigner_visualization.py` script.

        Args:
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            inputs_hw (Union[tuple, torch.Size]): Height and width of the
                input size.

        Returns:
            list[dict]: Assigning results, one entry per feature level.
        """
        if isinstance(batch_data_samples, list):
            raise NotImplementedError(
                'assigning results_list is not implemented')
        else:
            # Fast version: features, packed gt boxes/labels and image metas
            # arrive in a single dict instead of per-sample data structures.
            cls_scores, bbox_preds = self(batch_data_samples['feats'])
            assign_inputs = (cls_scores, bbox_preds,
                             batch_data_samples['bboxes_labels'],
                             batch_data_samples['img_metas'], inputs_hw)
        assign_results = self.assign_by_gt_and_feat(*assign_inputs)
        return assign_results
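

# A minimal usage sketch (illustrative only; ``model`` and ``packed_gts`` are
# assumptions, and the head is expected to be built as the model's bbox_head
# the way the assigner_visualization.py script sets it up):
#
#     feats = model.neck(model.backbone(batch_inputs))
#     batch_data_samples = dict(
#         feats=feats,
#         # packed gts, one row per box: [img_ind, label, x1, y1, x2, y2]
#         bboxes_labels=packed_gts,
#         img_metas=[dict(batch_input_shape=(640, 640))])
#     assign_results = model.bbox_head.assign(batch_data_samples, (640, 640))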