# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import torch
from mmdet.models.utils import unpack_gt_instances
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.models import YOLOv5Head
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLOv5HeadAssigner(YOLOv5Head):

    def assign_by_gt_and_feat(
        self,
        batch_gt_instances: Sequence[InstanceData],
        batch_img_metas: Sequence[dict],
        inputs_hw: Union[Tensor, tuple] = (640, 640)
    ) -> list:
        """Calculate the assigning results based on the gt and features
        extracted by the detection head.

        Args:
            batch_gt_instances (Sequence[InstanceData]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            inputs_hw (Union[Tensor, tuple]): Height and width of the input
                image.

        Returns:
            list: Per-level assigning results, one list of per-prior result
                dicts for each feature level.
        """
        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)
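
        # ``batch_targets_normed`` has shape (num_base_priors, num_gt, 7);
        # the last dim holds (img_ind, class_ind, cx, cy, w, h, prior_ind)
        # with the box in normalized cxcywh format. ``scaled_factor`` below
        # rescales only the cxcywh columns (2:6) to feature-map units.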
        device = batch_targets_normed.device
        scaled_factor = torch.ones(7, device=device)
        gt_inds = torch.arange(
            batch_targets_normed.shape[1],
            dtype=torch.long,
            device=device,
            requires_grad=False).unsqueeze(0).repeat((self.num_base_priors, 1))
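
        # ``gt_inds`` enumerates the gt boxes once per base prior so that,
        # after the boolean-mask filtering below, every retained target can
        # be traced back to its source gt box.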

        assign_results = []
        for i in range(self.num_levels):
            assign_results_feat = []
            h = inputs_hw[0] // self.featmap_strides[i]
            w = inputs_hw[1] // self.featmap_strides[i]

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                for k in range(self.num_base_priors):
                    assign_results_feat.append({
                        'stride':
                        self.featmap_strides[i],
                        'grid_x_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'grid_y_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'img_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'class_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'retained_gt_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'prior_ind':
                        k
                    })
                assign_results.append(assign_results_feat)
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor([w, h, w, h])
            # Scale batch_targets from range 0-1 to feature map size.
            # (num_base_priors, num_bboxes, 7)
            batch_targets_scaled = batch_targets_normed * scaled_factor
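
            # A gt box matches prior ``k`` when both its width and height lie
            # within a factor of ``prior_match_thr`` of the prior's size,
            # i.e. max(gt_wh / prior_wh, prior_wh / gt_wh) < prior_match_thr
            # (4.0 in the default YOLOv5 configs).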
            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]
            match_gt_inds = gt_inds[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                for k in range(self.num_base_priors):
                    assign_results_feat.append({
                        'stride':
                        self.featmap_strides[i],
                        'grid_x_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'grid_y_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'img_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'class_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'retained_gt_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'prior_ind':
                        k
                    })
                assign_results.append(assign_results_feat)
                continue

            # 3. Positive samples with additional neighbors

            # Check the left, up, right and bottom sides of the target's
            # grid cell, and decide whether to assign those neighbor cells
            # as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))
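
            # Row 0 of ``offset_inds`` keeps every matched target in its
            # center cell; rows 1-4 additionally select targets whose center
            # falls within ``near_neighbor_thr`` (0.5 by default) of the
            # left/up/right/bottom cell boundary, so one gt can be assigned
            # to up to three cells per level.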

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_gt_inds = match_gt_inds.repeat((5, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
                                                       1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds, (img_inds, class_inds) = priors_inds.long().view(
                -1), img_class_inds.long().T
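
            # Shift each duplicated target toward its neighbor cell before
            # truncating to integer grid coordinates: subtracting the
            # retained offset (scaled by ``near_neighbor_thr``) moves the
            # center across the relevant cell boundary.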
            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            for k in range(self.num_base_priors):
                retained_inds = priors_inds == k
                assign_results_prior = {
                    'stride': self.featmap_strides[i],
                    'grid_x_inds': grid_x_inds[retained_inds],
                    'grid_y_inds': grid_y_inds[retained_inds],
                    'img_inds': img_inds[retained_inds],
                    'class_inds': class_inds[retained_inds],
                    'retained_gt_inds': retained_gt_inds[retained_inds],
                    'prior_ind': k
                }
                assign_results_feat.append(assign_results_prior)
            assign_results.append(assign_results_feat)
        return assign_results

    def assign(self, batch_data_samples: Union[list, dict],
               inputs_hw: Union[tuple, torch.Size]) -> list:
        """Calculate assigning results. This function is provided to the
        `assigner_visualization.py` script.

        Args:
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            inputs_hw: Height and width of the input image.

        Returns:
            list: Per-level assigning results.
        """
        if isinstance(batch_data_samples, list):
            outputs = unpack_gt_instances(batch_data_samples)
            # ``unpack_gt_instances`` also returns the ignored gt instances,
            # which ``assign_by_gt_and_feat`` does not accept, so they are
            # unpacked but not forwarded.
            (batch_gt_instances, batch_gt_instances_ignore,
             batch_img_metas) = outputs

            assign_inputs = (batch_gt_instances, batch_img_metas, inputs_hw)
        else:
            # Fast version
            assign_inputs = (batch_data_samples['bboxes_labels'],
                             batch_data_samples['img_metas'], inputs_hw)
        assign_results = self.assign_by_gt_and_feat(*assign_inputs)

        return assign_results
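

# Usage sketch (hypothetical names, for illustration only): the head is
# normally built from a YOLOv5 config through the registry, after which
# ``assign`` can be called directly by the visualization script:
#
#   head = MODELS.build(head_cfg)  # head_cfg: YOLOv5HeadAssigner config dict
#   assign_results = head.assign(batch_data_samples, inputs_hw=(640, 640))
#   # assign_results[level][prior_ind]['grid_x_inds'] etc. give the assigned
#   # grid cells per feature level and per base prior.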