mirror of https://github.com/open-mmlab/mmyolo.git
RTMDet Assigner visualization (#528)
* fix format
* return multiple pos assigns
* rewrite to get matched_gt_inds
* ignore corrupted images
* rm RTMDetectorAssigner
* fix bug for different devices
* add warnings when use rtmdet without checkpoint
* add priors for rtmdet
* fix format
* add readme
* fix format
* fix readme and typo
* typo
* fix note
parent 164c319493
commit 75618020f8

projects/assigner_visualization
@@ -6,7 +6,7 @@
 This project is developed for easily showing assigning results. The script allows users to analyze where and how many positive samples each gt is assigned in the image.

-Now, the script only support `YOLOv5` .
+Now, the script supports `YOLOv5` and `RTMDet`.

 ## Usage
@@ -15,3 +15,13 @@ Now, the script only support `YOLOv5` .
 ```shell
 python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py
 ```
+
+Note: `YOLOv5` does not need to load the trained weights.
+
+```shell
+python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/rtmdet_s_syncbn_fast_8xb32-300e_coco_assignervisualization.py -c ${checkpoint}
+```
+
+`${checkpoint}` is the checkpoint file path. Dynamic label assignment is used in `RTMDet`, so the model weights affect the positive sample allocation results; it is therefore recommended to load the trained model weights.
+
+If you want to know the details of label assignment, you can check the [documentation](https://mmyolo.readthedocs.io/zh_CN/latest/algorithm_descriptions/rtmdet_description.html#id5).
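Taken together, the README commands map onto a few lines of Python. A minimal sketch of the script's startup sequence, not part of the commit, with a placeholder checkpoint path:

```python
from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmyolo.registry import MODELS
from mmyolo.utils import register_all_modules

# Minimal sketch of what assigner_visualization.py does at startup.
register_all_modules()
cfg = Config.fromfile(
    'projects/assigner_visualization/configs/'
    'rtmdet_s_syncbn_fast_8xb32-300e_coco_assignervisualization.py')
model = MODELS.build(cfg.model)
# RTMDet's assigner is dynamic, so trained weights matter here;
# the .pth path below is a placeholder.
_ = load_checkpoint(model, 'path/to/rtmdet_s.pth', map_location='cpu')
model.eval()
```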
@@ -3,6 +3,7 @@ import argparse
 import os
 import os.path as osp
 import sys
+import warnings

 import mmcv
 import numpy as np
@@ -10,11 +11,13 @@ import torch
 from mmengine import ProgressBar
 from mmengine.config import Config, DictAction
 from mmengine.dataset import COLLATE_FUNCTIONS
+from mmengine.runner import load_checkpoint
 from numpy import random

 from mmyolo.registry import DATASETS, MODELS
 from mmyolo.utils import register_all_modules
-from projects.assigner_visualization.dense_heads import YOLOv5HeadAssigner
+from projects.assigner_visualization.dense_heads import (RTMHeadAssigner,
+                                                         YOLOv5HeadAssigner)
 from projects.assigner_visualization.visualization import \
     YOLOAssignerVisualizer

@@ -24,6 +27,7 @@ def parse_args():
         description='MMYOLO show the positive sample assigning'
         ' results.')
     parser.add_argument('config', help='config file path')
+    parser.add_argument('--checkpoint', '-c', type=str, help='checkpoint file')
     parser.add_argument(
         '--show-number',
         '-n',
@@ -82,11 +86,20 @@ def main():

     # build model
     model = MODELS.build(cfg.model)
-    assert isinstance(model.bbox_head, YOLOv5HeadAssigner),\
-        'Now, this script only support yolov5, and bbox_head must use ' \
-        '`YOLOv5HeadAssigner`. Please use `' \
+    if args.checkpoint is not None:
+        _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
+    elif isinstance(model.bbox_head, RTMHeadAssigner):
+        warnings.warn(
+            'if you use dynamic_assignment methods such as yolov7 or '
+            'rtmdet assigner, please load the checkpoint.')
+
+    assert isinstance(model.bbox_head, (YOLOv5HeadAssigner, RTMHeadAssigner)),\
+        'Now, this script only support yolov5 and rtmdet, and ' \
+        'bbox_head must use ' \
+        '`YOLOv5HeadAssigner or RTMHeadAssigner`. Please use `' \
         'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \
-        '` as config file.'
+        'or rtmdet_s_syncbn_fast_8xb32-300e_coco_assignervisualization.py' \
+        """` as config file."""
     model.eval()
     model.to(args.device)
@@ -107,7 +120,9 @@ def main():
     }], name='visualizer')
     visualizer.dataset_meta = dataset.metainfo
     # need priors size to draw priors
-    visualizer.priors_size = model.bbox_head.prior_generator.base_anchors
+
+    if hasattr(model.bbox_head.prior_generator, 'base_anchors'):
+        visualizer.priors_size = model.bbox_head.prior_generator.base_anchors

     # make output dir
     os.makedirs(args.output_dir, exist_ok=True)
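The `hasattr` guard above is the behavioural core of this hunk: YOLOv5's anchor-based prior generator exposes `base_anchors`, while RTMDet's point-based generator does not, so `priors_size` is set only when anchors exist and the visualizer falls back to stride-sized boxes otherwise (see the visualizer hunks below). A self-contained way to see the difference, using the mmdet generator classes as stand-ins for the heads' actual generators (an assumption, not part of the commit):

```python
from mmdet.models.task_modules.prior_generators import (AnchorGenerator,
                                                        MlvlPointGenerator)

# Anchor-based generators expose `base_anchors`; point-based ones do not.
anchor_gen = AnchorGenerator(strides=[8], ratios=[1.0], scales=[1.0])
point_gen = MlvlPointGenerator(strides=[8])
print(hasattr(anchor_gen, 'base_anchors'))  # True
print(hasattr(point_gen, 'base_anchors'))   # False
```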
@@ -120,7 +135,10 @@ def main():
     progress_bar = ProgressBar(display_number)
     for ind_img in range(display_number):
         data = dataset.prepare_data(ind_img)
-
+        if data is None:
+            print('Unable to visualize {} due to strong data augmentations'.
+                  format(dataset[ind_img]['data_samples'].img_path))
+            continue
         # convert data to batch format
         batch_data = collate_fn([data])
         with torch.no_grad():
@@ -0,0 +1,9 @@
+_base_ = ['../../../configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py']
+
+custom_imports = dict(imports=[
+    'projects.assigner_visualization.detectors',
+    'projects.assigner_visualization.dense_heads'
+])
+
+model = dict(
+    type='YOLODetectorAssigner', bbox_head=dict(type='RTMHeadAssigner'))
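`custom_imports` tells mmengine to import the two project-local packages when the config is loaded, which runs their `@MODELS.register_module()` decorators before `MODELS.build` resolves the `type` strings. A rough manual equivalent, not part of the commit:

```python
# Sketch: what the custom_imports entry above does for you.
import importlib

for mod in ('projects.assigner_visualization.detectors',
            'projects.assigner_visualization.dense_heads'):
    # importing triggers the @MODELS.register_module() decorators, so
    # 'YOLODetectorAssigner' and 'RTMHeadAssigner' become buildable types
    importlib.import_module(mod)
```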
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .rtmdet_head_assigner import RTMHeadAssigner
 from .yolov5_head_assigner import YOLOv5HeadAssigner

-__all__ = ['YOLOv5HeadAssigner']
+__all__ = ['YOLOv5HeadAssigner', 'RTMHeadAssigner']
@@ -0,0 +1,169 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
+import torch
+from mmdet.structures.bbox import distance2bbox
+from mmdet.utils import InstanceList
+from torch import Tensor
+
+from mmyolo.models import RTMDetHead
+from mmyolo.registry import MODELS
+
+
+@MODELS.register_module()
+class RTMHeadAssigner(RTMDetHead):
+
+    def assign_by_gt_and_feat(
+        self,
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        batch_gt_instances: InstanceList,
+        batch_img_metas: List[dict],
+        inputs_hw: Union[Tensor, tuple] = (640, 640)
+    ) -> dict:
+        """Calculate the assigning results based on the gt and features
+        extracted by the detection head.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level
+                Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Decoded box for each scale
+                level with shape (N, num_anchors * 4, H, W) in
+                [tl_x, tl_y, br_x, br_y] format.
+            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+                gt_instance. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
+        Returns:
+            dict[str, Tensor]: A dictionary of assigning results.
+        """
+        num_imgs = len(batch_img_metas)
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == self.prior_generator.num_levels
+
+        gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
+        gt_labels = gt_info[:, :, :1]
+        gt_bboxes = gt_info[:, :, 1:]  # xyxy
+        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
+
+        device = cls_scores[0].device
+
+        # If the shape does not equal, generate new one
+        if featmap_sizes != self.featmap_sizes_train:
+            self.featmap_sizes_train = featmap_sizes
+            mlvl_priors_with_stride = self.prior_generator.grid_priors(
+                featmap_sizes, device=device, with_stride=True)
+            self.flatten_priors_train = torch.cat(
+                mlvl_priors_with_stride, dim=0)
+
+        flatten_cls_scores = torch.cat([
+            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+                                                  self.cls_out_channels)
+            for cls_score in cls_scores
+        ], 1).contiguous()
+
+        flatten_bboxes = torch.cat([
+            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+            for bbox_pred in bbox_preds
+        ], 1)
+        flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
+                                                                    None]
+        flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
+                                       flatten_bboxes)
+
+        assigned_result = self.assigner(flatten_bboxes.detach(),
+                                        flatten_cls_scores.detach(),
+                                        self.flatten_priors_train, gt_labels,
+                                        gt_bboxes, pad_bbox_flag)
+
+        labels = assigned_result['assigned_labels'].reshape(-1)
+        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
+
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        pos_inds = ((labels >= 0)
+                    & (labels < bg_class_ind)).nonzero().squeeze(1)
+        targets = bbox_targets[pos_inds]
+        gt_bboxes = gt_bboxes.squeeze(0)
+        matched_gt_inds = torch.tensor(
+            [((t == gt_bboxes).sum(dim=1) == t.shape[0]).nonzero()[0]
+             for t in targets],
+            device=device)
+
+        level_inds = torch.zeros_like(labels)
+        img_inds = torch.zeros_like(labels)
+        level_nums = [0] + [f[0] * f[1] for f in featmap_sizes]
+        for i in range(len(level_nums) - 1):
+            level_nums[i + 1] = level_nums[i] + level_nums[i + 1]
+            level_inds[level_nums[i]:level_nums[i + 1]] = i
+        level_inds_pos = level_inds[pos_inds]
+
+        img_inds = img_inds[pos_inds]
+        labels = labels[pos_inds]
+
+        inputs_hw = batch_img_metas[0]['batch_input_shape']
+        assign_results = []
+        for i in range(self.num_levels):
+            retained_inds = level_inds_pos == i
+            if not retained_inds.any():
+                assign_results_prior = {
+                    'stride':
+                    self.featmap_strides[i],
+                    'grid_x_inds':
+                    torch.zeros([0], dtype=torch.int64).to(device),
+                    'grid_y_inds':
+                    torch.zeros([0], dtype=torch.int64).to(device),
+                    'img_inds':
+                    torch.zeros([0], dtype=torch.int64).to(device),
+                    'class_inds':
+                    torch.zeros([0], dtype=torch.int64).to(device),
+                    'retained_gt_inds':
+                    torch.zeros([0], dtype=torch.int64).to(device),
+                    'prior_ind':
+                    0
+                }
+            else:
+                w = inputs_hw[1] // self.featmap_strides[i]
+
+                retained_pos_inds = pos_inds[retained_inds] - level_nums[i]
+                grid_y_inds = retained_pos_inds // w
+                grid_x_inds = retained_pos_inds - retained_pos_inds // w * w
+                assign_results_prior = {
+                    'stride': self.featmap_strides[i],
+                    'grid_x_inds': grid_x_inds,
+                    'grid_y_inds': grid_y_inds,
+                    'img_inds': img_inds[retained_inds],
+                    'class_inds': labels[retained_inds],
+                    'retained_gt_inds': matched_gt_inds[retained_inds],
+                    'prior_ind': 0
+                }
+            assign_results.append([assign_results_prior])
+        return assign_results
+
+    def assign(self, batch_data_samples: Union[list, dict],
+               inputs_hw: Union[tuple, torch.Size]) -> dict:
+        """Calculate assigning results. This function is provided to the
+        `assigner_visualization.py` script.
+
+        Args:
+            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
+                Samples. It usually includes information such as
+                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+            inputs_hw: Height and width of inputs size
+
+        Returns:
+            dict: A dictionary of assigning components.
+        """
+        if isinstance(batch_data_samples, list):
+            raise NotImplementedError(
+                'assigning results_list is not implemented')
+        else:
+            # Fast version
+            cls_scores, bbox_preds = self(batch_data_samples['feats'])
+            assign_inputs = (cls_scores, bbox_preds,
+                             batch_data_samples['bboxes_labels'],
+                             batch_data_samples['img_metas'], inputs_hw)
+        assign_results = self.assign_by_gt_and_feat(*assign_inputs)
+        return assign_results
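Two index computations in `assign_by_gt_and_feat` are worth sanity-checking in isolation: `matched_gt_inds` recovers, for each positive target box, its row in `gt_bboxes` by exact coordinate equality, and the grid indices decode a flat per-level prior index into `(x, y)` cell coordinates. A toy check with hand-made tensors, not part of the commit:

```python
import torch

# 1) matched_gt_inds: map each assigned target box back to its row in
#    gt_bboxes by exact coordinate match (same expression as the diff).
gt_bboxes = torch.tensor([[0., 0., 10., 10.],
                          [5., 5., 20., 20.]])
targets = torch.tensor([[5., 5., 20., 20.],
                        [0., 0., 10., 10.]])
matched_gt_inds = torch.tensor(
    [((t == gt_bboxes).sum(dim=1) == t.shape[0]).nonzero()[0]
     for t in targets])
print(matched_gt_inds.flatten().tolist())  # [1, 0]

# 2) grid coordinates: a flat prior index within one level decodes to
#    (x, y) on that level's feature map of width w.
w = 80  # e.g. 640 // 8 on the stride-8 level
retained_pos_inds = torch.tensor([5, 83])
grid_y_inds = retained_pos_inds // w
grid_x_inds = retained_pos_inds - retained_pos_inds // w * w
print(grid_x_inds.tolist(), grid_y_inds.tolist())  # [5, 3] [0, 1]
```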
@@ -3,6 +3,7 @@ from typing import Union

 from mmyolo.models import YOLODetector
 from mmyolo.registry import MODELS
+from projects.assigner_visualization.dense_heads import RTMHeadAssigner


 @MODELS.register_module()
@@ -22,6 +23,8 @@ class YOLODetectorAssigner(YOLODetector):
         assert isinstance(data, dict)
         assert len(data['inputs']) == 1, 'Only support batchsize == 1'
         data = self.data_preprocessor(data, True)
+        if isinstance(self.bbox_head, RTMHeadAssigner):
+            data['data_samples']['feats'] = self.extract_feat(data['inputs'])
         inputs_hw = data['inputs'].shape[-2:]
         assign_results = self.bbox_head.assign(data['data_samples'], inputs_hw)
         return assign_results
@@ -218,12 +218,17 @@ class YOLOAssignerVisualizer(DetLocalVisualizer):
             with corresponding stride. Defaults to 0.5.
         """

-        palette = self.dataset_meta['PALETTE']
+        palette = self.dataset_meta['palette']
         center_x = ((grid_x_inds + offset) * stride)
         center_y = ((grid_y_inds + offset) * stride)
         xyxy = torch.stack((center_x, center_y, center_x, center_y), dim=1)
-        assert self.priors_size is not None
-        xyxy += self.priors_size[feat_ind][prior_ind]
+        device = xyxy.device
+        if self.priors_size is not None:
+            xyxy += self.priors_size[feat_ind][prior_ind].to(device)
+        else:
+            xyxy += torch.tensor(
+                [[-stride / 2, -stride / 2, stride / 2, stride / 2]],
+                device=device)

         colors = [palette[i] for i in class_inds]
         self.draw_bboxes(
@@ -284,7 +289,10 @@ class YOLOAssignerVisualizer(DetLocalVisualizer):
                 retained_gt_inds)

         # draw title
-        base_prior = self.priors_size[feat_ind][prior_ind]
+        if self.priors_size is not None:
+            base_prior = self.priors_size[feat_ind][prior_ind]
+        else:
+            base_prior = [stride, stride, stride * 2, stride * 2]
         prior_size = (base_prior[2] - base_prior[0],
                       base_prior[3] - base_prior[1])
         pos = np.array((20, 20))
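When `priors_size` is absent (the anchor-free RTMDet case), the `else` branches above draw a stride-sized square centred on each positive grid cell. Worked through for `stride=8`, `offset=0.5` and cell `(3, 2)`, not part of the commit:

```python
import torch

# Fallback prior box for an anchor-free head, with the constants from
# the diff above: a stride-sized square around the positive cell.
stride, offset = 8, 0.5
grid_x_inds = torch.tensor([3])
grid_y_inds = torch.tensor([2])
center_x = (grid_x_inds + offset) * stride  # 28.0
center_y = (grid_y_inds + offset) * stride  # 20.0
xyxy = torch.stack((center_x, center_y, center_x, center_y), dim=1)
xyxy += torch.tensor([[-stride / 2, -stride / 2, stride / 2, stride / 2]])
print(xyxy.tolist())  # [[24.0, 16.0, 32.0, 24.0]]
```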