mirror of https://github.com/open-mmlab/mmyolo.git
175 lines
7.3 KiB
Python
175 lines
7.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List, Union
|
|
|
|
import torch
|
|
from mmdet.structures.bbox import distance2bbox
|
|
from mmdet.utils import InstanceList
|
|
from torch import Tensor
|
|
|
|
from mmyolo.models import RTMDetHead
|
|
from mmyolo.models.utils import gt_instances_preprocess
|
|
from mmyolo.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module()
class RTMHeadAssigner(RTMDetHead):
    """RTMDet head that additionally exposes its label-assignment results.

    It reuses the RTMDet head (prior generator, assigner and forward pass) and
    adds :meth:`assign`, which reports which grid cell of which feature level
    was assigned to which ground-truth box.
    """

    def assign_by_gt_and_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        inputs_hw: Union[Tensor, tuple] = (640, 640)
    ) -> List[List[dict]]:
        """Calculate the assigning results based on the gt and features
        extracted by the detection head.

        Any batch size is supported: every positive prior is reported
        together with the index of the image it belongs to.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box regression outputs for each scale
                level with shape (N, num_anchors * 4, H, W). They are
                distances from the prior point to the box sides in units of
                the level stride and are decoded here with
                ``distance2bbox``.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc. Only its length (the
                batch size) is used.
            inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
                Kept for API compatibility only: grid coordinates are derived
                from the feature map sizes, which is exact even when the
                input size is not divisible by a stride.

        Returns:
            list[list[dict]]: One single-element list per feature level. The
            dict holds ``stride``, ``grid_x_inds``, ``grid_y_inds``,
            ``img_inds``, ``class_inds``, ``retained_gt_inds`` (index of the
            matched gt within its own image), ``prior_ind`` and ``offset``.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        # Forwarded so grid centers can be placed with the same offset the
        # prior generator uses.
        prior_offset = self.prior_generator.offset

        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        # Rows whose coordinates sum to zero are treated as padding.
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        device = cls_scores[0].device

        # Regenerate the cached priors only when the feature map shapes change.
        if featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                featmap_sizes, device=device, with_stride=True)
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)

        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1).contiguous()

        flatten_bboxes = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], 1)
        # Scale the distances by the prior stride (last prior column), then
        # decode them around the prior points into xyxy boxes.
        flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
                                                                    None]
        flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
                                       flatten_bboxes)

        assigned_result = self.assigner(flatten_bboxes.detach(),
                                        flatten_cls_scores.detach(),
                                        self.flatten_priors_train, gt_labels,
                                        gt_bboxes, pad_bbox_flag)

        # NOTE(review): assumes the assigner returns image-major tensors of
        # shape (num_imgs, num_priors[, 4]); they are flattened here.
        labels = assigned_result['assigned_labels'].reshape(-1)
        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)

        # Split every flat positive index into (image, prior within image).
        level_sizes = [int(h) * int(w) for h, w in featmap_sizes]
        num_priors = sum(level_sizes)
        img_inds = pos_inds // num_priors
        prior_inds = pos_inds % num_priors
        class_inds = labels[pos_inds]
        matched_gt_inds = self._match_gt_inds(bbox_targets[pos_inds],
                                              gt_bboxes, img_inds)

        assign_results = []
        level_start = 0
        for i, (level_size, featmap_size) in enumerate(
                zip(level_sizes, featmap_sizes)):
            level_end = level_start + level_size
            retained_inds = (prior_inds >= level_start) & (
                prior_inds < level_end)
            stride = self.featmap_strides[i]
            if retained_inds.any():
                # Assumes each level's priors are laid out row-major, so the
                # feature map width is the row length.
                feat_w = int(featmap_size[1])
                level_prior_inds = prior_inds[retained_inds] - level_start
                assign_results_prior = self._level_result(
                    stride,
                    prior_offset,
                    grid_x_inds=level_prior_inds % feat_w,
                    grid_y_inds=level_prior_inds // feat_w,
                    img_inds=img_inds[retained_inds],
                    class_inds=class_inds[retained_inds],
                    retained_gt_inds=matched_gt_inds[retained_inds])
            else:
                no_inds = [
                    torch.zeros([0], dtype=torch.int64, device=device)
                    for _ in range(5)
                ]
                assign_results_prior = self._level_result(
                    stride, prior_offset, *no_inds)
            assign_results.append([assign_results_prior])
            level_start = level_end
        return assign_results

    @staticmethod
    def _match_gt_inds(pos_targets: Tensor, gt_bboxes: Tensor,
                       img_inds: Tensor) -> Tensor:
        """Recover the gt index each positive target was taken from.

        The assigner's result is only read for labels and boxes, so the gt
        index is recovered by exact coordinate equality against the gt boxes
        of the positive's own image. If several gts share identical
        coordinates, the first one is returned.

        Args:
            pos_targets (Tensor): Assigned boxes of the positive priors,
                shape (P, 4).
            gt_bboxes (Tensor): Padded gt boxes, shape
                (num_imgs, max_num_gts, 4).
            img_inds (Tensor): Image index of every positive prior, shape
                (P, ).

        Returns:
            Tensor: int64 gt index within its image, shape (P, ).

        Raises:
            RuntimeError: If some target equals none of its image's gt boxes.
        """
        if pos_targets.numel() == 0:
            return torch.zeros(0, dtype=torch.long, device=pos_targets.device)
        # (P, max_num_gts): True where the target equals that gt box.
        is_match = (gt_bboxes[img_inds] == pos_targets[:, None, :]).all(dim=-1)
        if not bool(is_match.any(dim=1).all()):
            raise RuntimeError(
                'Some assigned boxes do not match any ground-truth box.')
        # argmax returns the first maximal index, i.e. the first matching gt.
        return is_match.long().argmax(dim=1)

    @staticmethod
    def _level_result(stride, offset, grid_x_inds: Tensor,
                      grid_y_inds: Tensor, img_inds: Tensor,
                      class_inds: Tensor, retained_gt_inds: Tensor) -> dict:
        """Pack the assignment of one feature level into the per-level dict
        returned by :meth:`assign_by_gt_and_feat` (``prior_ind`` is always 0
        here)."""
        return {
            'stride': stride,
            'grid_x_inds': grid_x_inds,
            'grid_y_inds': grid_y_inds,
            'img_inds': img_inds,
            'class_inds': class_inds,
            'retained_gt_inds': retained_gt_inds,
            'prior_ind': 0,
            'offset': offset
        }

    def assign(self, batch_data_samples: Union[list, dict],
               inputs_hw: Union[tuple, torch.Size]) -> List[List[dict]]:
        """Calculate assigning results. This function is provided to the
        `assigner_visualization.py` script.

        Args:
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. Only the dict form is supported; it must provide
                ``feats`` (fed to the head's forward), ``bboxes_labels`` and
                ``img_metas``.
            inputs_hw: Height and width of inputs size.

        Returns:
            list[list[dict]]: Per-level assigning results, see
            :meth:`assign_by_gt_and_feat`.

        Raises:
            NotImplementedError: If ``batch_data_samples`` is a list.
        """
        if isinstance(batch_data_samples, list):
            raise NotImplementedError(
                'assigning results_list is not implemented')
        # Fast version
        cls_scores, bbox_preds = self(batch_data_samples['feats'])
        return self.assign_by_gt_and_feat(cls_scores, bbox_preds,
                                          batch_data_samples['bboxes_labels'],
                                          batch_data_samples['img_metas'],
                                          inputs_hw)
|