mmyolo/projects/assigner_visualization/dense_heads/yolov5_head_assigner.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import torch
from mmdet.models.utils import unpack_gt_instances
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.models import YOLOv5Head
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLOv5HeadAssigner(YOLOv5Head):

    def assign_by_gt_and_feat(
self,
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
inputs_hw: Union[Tensor, tuple] = (640, 640)
    ) -> list:
"""Calculate the assigning results based on the gt and features
extracted by the detection head.
Args:
batch_gt_instances (Sequence[InstanceData]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (Sequence[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
Returns:
dict[str, Tensor]: A dictionary of assigning results.
"""
# 1. Convert gt to norm format
batch_targets_normed = self._convert_gt_to_norm_format(
batch_gt_instances, batch_img_metas)
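        # batch_targets_normed has shape (num_base_priors, num_gts, 7);
        # each target row is (img_ind, class_ind, cx, cy, w, h, prior_ind)
        # with the box normalized to the range [0, 1].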
device = batch_targets_normed.device
scaled_factor = torch.ones(7, device=device)
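        # Record the original gt index of every (prior, gt) pair so the
        # retained gts can still be identified after the boolean filtering
        # below.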
gt_inds = torch.arange(
batch_targets_normed.shape[1],
dtype=torch.long,
device=device,
requires_grad=False).unsqueeze(0).repeat((self.num_base_priors, 1))
assign_results = []
for i in range(self.num_levels):
assign_results_feat = []
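            # Feature map size at this level: input size divided by stride.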
h = inputs_hw[0] // self.featmap_strides[i]
w = inputs_hw[1] // self.featmap_strides[i]
# empty gt bboxes
if batch_targets_normed.shape[1] == 0:
for k in range(self.num_base_priors):
assign_results_feat.append({
'stride':
self.featmap_strides[i],
'grid_x_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'grid_y_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'img_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'class_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'retained_gt_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'prior_ind':
k
})
assign_results.append(assign_results_feat)
continue
priors_base_sizes_i = self.priors_base_sizes[i]
# feature map scale whwh
scaled_factor[2:6] = torch.tensor([w, h, w, h])
            # Scale batch_targets from range 0-1 to the feature map size.
# (num_base_priors, num_bboxes, 7)
batch_targets_scaled = batch_targets_normed * scaled_factor
# 2. Shape match
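            # Ratio test: a gt matches a prior when
            # max(gt_wh / prior_wh, prior_wh / gt_wh) over both w and h is
            # below self.prior_match_thr (4.0 by default for YOLOv5).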
wh_ratio = batch_targets_scaled[...,
4:6] / priors_base_sizes_i[:, None]
match_inds = torch.max(
wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
batch_targets_scaled = batch_targets_scaled[match_inds]
match_gt_inds = gt_inds[match_inds]
            # no gt bbox matches any anchor
if batch_targets_scaled.shape[0] == 0:
for k in range(self.num_base_priors):
assign_results_feat.append({
'stride':
self.featmap_strides[i],
'grid_x_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'grid_y_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'img_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'class_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'retained_gt_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'prior_ind':
k
})
assign_results.append(assign_results_feat)
continue
            # 3. Positive samples with additional neighbors
            # Check the left, up, right and bottom sides of the
            # target's grid cell, and decide whether to assign the
            # neighboring cells as positive samples as well.
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
(batch_targets_cxcy > 1)).T
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
(grid_xy > 1)).T
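            # left/up are True when the gt center lies in the left/top half
            # of its cell (excluding border cells); right/bottom test the
            # distance to the far edges, so each gt keeps its center cell
            # plus at most two adjacent cells.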
offset_inds = torch.stack(
(torch.ones_like(left), left, up, right, bottom))
batch_targets_scaled = batch_targets_scaled.repeat(
(5, 1, 1))[offset_inds]
retained_gt_inds = match_gt_inds.repeat((5, 1))[offset_inds]
retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
1)[offset_inds]
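            # Replicate every matched target 5 times (center + 4 neighbor
            # candidates) and keep only the rows enabled by offset_inds;
            # retained_offsets stores the corresponding cell offsets.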
            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
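            # Chunk the 7 target columns into (img_ind, class_ind),
            # (cx, cy), (w, h) and (prior_ind).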
_chunk_targets = batch_targets_scaled.chunk(4, 1)
img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds = priors_inds.long().view(-1)
            img_inds, class_inds = img_class_inds.long().T
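            # Shift the center by the retained offset times
            # near_neighbor_thr (0.5 by default) and truncate to get the
            # integer grid coordinates of each positive cell.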
grid_xy_long = (grid_xy -
retained_offsets * self.near_neighbor_thr).long()
grid_x_inds, grid_y_inds = grid_xy_long.T
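            # Group the positive samples by prior (anchor) index, one
            # result dict per prior.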
for k in range(self.num_base_priors):
retained_inds = priors_inds == k
assign_results_prior = {
'stride': self.featmap_strides[i],
'grid_x_inds': grid_x_inds[retained_inds],
'grid_y_inds': grid_y_inds[retained_inds],
'img_inds': img_inds[retained_inds],
'class_inds': class_inds[retained_inds],
'retained_gt_inds': retained_gt_inds[retained_inds],
'prior_ind': k
}
assign_results_feat.append(assign_results_prior)
assign_results.append(assign_results_feat)
        return assign_results

    def assign(self, batch_data_samples: Union[list, dict],
               inputs_hw: Union[tuple, torch.Size]) -> list:
"""Calculate assigning results. This function is provided to the
`assigner_visualization.py` script.
Args:
batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
inputs_hw: Height and width of inputs size
Returns:
dict: A dictionary of assigning components.
"""
        if isinstance(batch_data_samples, list):
            outputs = unpack_gt_instances(batch_data_samples)
            (batch_gt_instances, batch_gt_instances_ignore,
             batch_img_metas) = outputs
            # batch_gt_instances_ignore is not used by
            # assign_by_gt_and_feat, so it is not forwarded.
            assign_inputs = (batch_gt_instances, batch_img_metas, inputs_hw)
        else:
            # Fast version
            assign_inputs = (batch_data_samples['bboxes_labels'],
                             batch_data_samples['img_metas'], inputs_hw)
assign_results = self.assign_by_gt_and_feat(*assign_inputs)
return assign_results
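

# A minimal usage sketch. Note that `head_module_cfg` and
# `batch_data_samples` below are placeholders, not the project's actual
# config; the real entry point is the project's assigner_visualization.py
# script:
#
#   from mmyolo.registry import MODELS
#   head = MODELS.build(
#       dict(type='YOLOv5HeadAssigner', head_module=head_module_cfg))
#   assign_results = head.assign(batch_data_samples, inputs_hw=(640, 640))
#   # assign_results[level][prior_ind]['grid_x_inds'] holds the grid columns
#   # of the positive cells assigned to that prior at that level.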