RTMDet Assigner visualization ()

* fix format

* return multiple pos assigns

* rewrite to get matched_gt_inds

* ignore corrupted images

* rm RTMDetectorAssigner

* fix bug for different devices

* add warnings when use rtmdet without checkpoint

* add priors for rtmdet

* fix format

* add readme

* fix format

* fix readme and typo

* typo

* fix note
pull/551/head
yechenzhi 2023-02-13 11:42:11 +08:00 committed by GitHub
parent 164c319493
commit 75618020f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 231 additions and 13 deletions

View File

@ -6,7 +6,7 @@
This project is developed for easily showing assigning results. The script allows users to analyze where and how many positive samples each gt is assigned in the image.
Now, the script only support `YOLOv5` .
Now, the script supports `YOLOv5` and `RTMDet`.
## Usage
@ -15,3 +15,13 @@ Now, the script only support `YOLOv5` .
```shell
python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py
```
Note: `YOLOv5` does not need to load the trained weights.
```shell
python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/rtmdet_s_syncbn_fast_8xb32-300e_coco_assignervisualization.py -c ${checkpont}
```
${checkpont} is the checkpont file path. Dynamic label assignment is used in `RTMDet`, model weights will affect the positive sample allocation results, so it is recommended to load the trained model weights.
If you want to know details about label assignment, you can check the [documentation](https://mmyolo.readthedocs.io/zh_CN/latest/algorithm_descriptions/rtmdet_description.html#id5).

View File

@ -3,6 +3,7 @@ import argparse
import os
import os.path as osp
import sys
import warnings
import mmcv
import numpy as np
@ -10,11 +11,13 @@ import torch
from mmengine import ProgressBar
from mmengine.config import Config, DictAction
from mmengine.dataset import COLLATE_FUNCTIONS
from mmengine.runner import load_checkpoint
from numpy import random
from mmyolo.registry import DATASETS, MODELS
from mmyolo.utils import register_all_modules
from projects.assigner_visualization.dense_heads import YOLOv5HeadAssigner
from projects.assigner_visualization.dense_heads import (RTMHeadAssigner,
YOLOv5HeadAssigner)
from projects.assigner_visualization.visualization import \
YOLOAssignerVisualizer
@ -24,6 +27,7 @@ def parse_args():
description='MMYOLO show the positive sample assigning'
' results.')
parser.add_argument('config', help='config file path')
parser.add_argument('--checkpoint', '-c', type=str, help='checkpoint file')
parser.add_argument(
'--show-number',
'-n',
@ -82,11 +86,20 @@ def main():
# build model
model = MODELS.build(cfg.model)
assert isinstance(model.bbox_head, YOLOv5HeadAssigner),\
'Now, this script only support yolov5, and bbox_head must use ' \
'`YOLOv5HeadAssigner`. Please use `' \
if args.checkpoint is not None:
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
elif isinstance(model.bbox_head, RTMHeadAssigner):
warnings.warn(
'if you use dynamic_assignment methods such as yolov7 or '
'rtmdet assigner, please load the checkpoint.')
assert isinstance(model.bbox_head, (YOLOv5HeadAssigner, RTMHeadAssigner)),\
'Now, this script only support yolov5 and rtmdet, and ' \
'bbox_head must use ' \
'`YOLOv5HeadAssigner or RTMHeadAssigner`. Please use `' \
'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \
'` as config file.'
'or rtmdet_s_syncbn_fast_8xb32-300e_coco_assignervisualization.py' \
"""` as config file."""
model.eval()
model.to(args.device)
@ -107,7 +120,9 @@ def main():
}], name='visualizer')
visualizer.dataset_meta = dataset.metainfo
# need priors size to draw priors
visualizer.priors_size = model.bbox_head.prior_generator.base_anchors
if hasattr(model.bbox_head.prior_generator, 'base_anchors'):
visualizer.priors_size = model.bbox_head.prior_generator.base_anchors
# make output dir
os.makedirs(args.output_dir, exist_ok=True)
@ -120,7 +135,10 @@ def main():
progress_bar = ProgressBar(display_number)
for ind_img in range(display_number):
data = dataset.prepare_data(ind_img)
if data is None:
print('Unable to visualize {} due to strong data augmentations'.
format(dataset[ind_img]['data_samples'].img_path))
continue
# convert data to batch format
batch_data = collate_fn([data])
with torch.no_grad():

View File

@ -0,0 +1,9 @@
_base_ = ['../../../configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py']
custom_imports = dict(imports=[
'projects.assigner_visualization.detectors',
'projects.assigner_visualization.dense_heads'
])
model = dict(
type='YOLODetectorAssigner', bbox_head=dict(type='RTMHeadAssigner'))

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .rtmdet_head_assigner import RTMHeadAssigner
from .yolov5_head_assigner import YOLOv5HeadAssigner
__all__ = ['YOLOv5HeadAssigner']
__all__ = ['YOLOv5HeadAssigner', 'RTMHeadAssigner']

View File

@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import InstanceList
from torch import Tensor
from mmyolo.models import RTMDetHead
from mmyolo.registry import MODELS
@MODELS.register_module()
class RTMHeadAssigner(RTMDetHead):
def assign_by_gt_and_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
inputs_hw: Union[Tensor, tuple] = (640, 640)
) -> dict:
"""Calculate the assigning results based on the gt and features
extracted by the detection head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Decoded box for each scale
level with shape (N, num_anchors * 4, H, W) in
[tl_x, tl_y, br_x, br_y] format.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
Returns:
dict[str, Tensor]: A dictionary of assigning results.
"""
num_imgs = len(batch_img_metas)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.prior_generator.num_levels
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
device = cls_scores[0].device
# If the shape does not equal, generate new one
if featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
self.flatten_priors_train = torch.cat(
mlvl_priors_with_stride, dim=0)
flatten_cls_scores = torch.cat([
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.cls_out_channels)
for cls_score in cls_scores
], 1).contiguous()
flatten_bboxes = torch.cat([
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
], 1)
flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
None]
flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
flatten_bboxes)
assigned_result = self.assigner(flatten_bboxes.detach(),
flatten_cls_scores.detach(),
self.flatten_priors_train, gt_labels,
gt_bboxes, pad_bbox_flag)
labels = assigned_result['assigned_labels'].reshape(-1)
bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
targets = bbox_targets[pos_inds]
gt_bboxes = gt_bboxes.squeeze(0)
matched_gt_inds = torch.tensor(
[((t == gt_bboxes).sum(dim=1) == t.shape[0]).nonzero()[0]
for t in targets],
device=device)
level_inds = torch.zeros_like(labels)
img_inds = torch.zeros_like(labels)
level_nums = [0] + [f[0] * f[1] for f in featmap_sizes]
for i in range(len(level_nums) - 1):
level_nums[i + 1] = level_nums[i] + level_nums[i + 1]
level_inds[level_nums[i]:level_nums[i + 1]] = i
level_inds_pos = level_inds[pos_inds]
img_inds = img_inds[pos_inds]
labels = labels[pos_inds]
inputs_hw = batch_img_metas[0]['batch_input_shape']
assign_results = []
for i in range(self.num_levels):
retained_inds = level_inds_pos == i
if not retained_inds.any():
assign_results_prior = {
'stride':
self.featmap_strides[i],
'grid_x_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'grid_y_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'img_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'class_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'retained_gt_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'prior_ind':
0
}
else:
w = inputs_hw[1] // self.featmap_strides[i]
retained_pos_inds = pos_inds[retained_inds] - level_nums[i]
grid_y_inds = retained_pos_inds // w
grid_x_inds = retained_pos_inds - retained_pos_inds // w * w
assign_results_prior = {
'stride': self.featmap_strides[i],
'grid_x_inds': grid_x_inds,
'grid_y_inds': grid_y_inds,
'img_inds': img_inds[retained_inds],
'class_inds': labels[retained_inds],
'retained_gt_inds': matched_gt_inds[retained_inds],
'prior_ind': 0
}
assign_results.append([assign_results_prior])
return assign_results
def assign(self, batch_data_samples: Union[list, dict],
inputs_hw: Union[tuple, torch.Size]) -> dict:
"""Calculate assigning results. This function is provided to the
`assigner_visualization.py` script.
Args:
batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
inputs_hw: Height and width of inputs size
Returns:
dict: A dictionary of assigning components.
"""
if isinstance(batch_data_samples, list):
raise NotImplementedError(
'assigning results_list is not implemented')
else:
# Fast version
cls_scores, bbox_preds = self(batch_data_samples['feats'])
assign_inputs = (cls_scores, bbox_preds,
batch_data_samples['bboxes_labels'],
batch_data_samples['img_metas'], inputs_hw)
assign_results = self.assign_by_gt_and_feat(*assign_inputs)
return assign_results

View File

@ -3,6 +3,7 @@ from typing import Union
from mmyolo.models import YOLODetector
from mmyolo.registry import MODELS
from projects.assigner_visualization.dense_heads import RTMHeadAssigner
@MODELS.register_module()
@ -22,6 +23,8 @@ class YOLODetectorAssigner(YOLODetector):
assert isinstance(data, dict)
assert len(data['inputs']) == 1, 'Only support batchsize == 1'
data = self.data_preprocessor(data, True)
if isinstance(self.bbox_head, RTMHeadAssigner):
data['data_samples']['feats'] = self.extract_feat(data['inputs'])
inputs_hw = data['inputs'].shape[-2:]
assign_results = self.bbox_head.assign(data['data_samples'], inputs_hw)
return assign_results

View File

@ -218,12 +218,17 @@ class YOLOAssignerVisualizer(DetLocalVisualizer):
with corresponding stride. Defaults to 0.5.
"""
palette = self.dataset_meta['PALETTE']
palette = self.dataset_meta['palette']
center_x = ((grid_x_inds + offset) * stride)
center_y = ((grid_y_inds + offset) * stride)
xyxy = torch.stack((center_x, center_y, center_x, center_y), dim=1)
assert self.priors_size is not None
xyxy += self.priors_size[feat_ind][prior_ind]
device = xyxy.device
if self.priors_size is not None:
xyxy += self.priors_size[feat_ind][prior_ind].to(device)
else:
xyxy += torch.tensor(
[[-stride / 2, -stride / 2, stride / 2, stride / 2]],
device=device)
colors = [palette[i] for i in class_inds]
self.draw_bboxes(
@ -284,7 +289,10 @@ class YOLOAssignerVisualizer(DetLocalVisualizer):
retained_gt_inds)
# draw title
base_prior = self.priors_size[feat_ind][prior_ind]
if self.priors_size is not None:
base_prior = self.priors_size[feat_ind][prior_ind]
else:
base_prior = [stride, stride, stride * 2, stride * 2]
prior_size = (base_prior[2] - base_prior[0],
base_prior[3] - base_prior[1])
pos = np.array((20, 20))