mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Show YOLOv5 assigner results (#383)
* init commit * init commit * init commit * finalize draft, start refactoring * format code * format code * add typehint and doc * init commit * init commit * init commit * finalize draft, start refactoring * format code * format code * add typehint and doc * format code * rollback * add doc * fix less img bug * format code * format code * add README.md * beauty * beauty * uniform name * uniform name * uniform name * uniform name
parent
bb4aea90da
commit
9ef883187a
projects/assigner_visualization
@ -0,0 +1,17 @@
# MMYOLO Model Assigner Visualization

<img src="https://user-images.githubusercontent.com/40284075/208255302-dbcf8cb0-b9d1-495f-8908-57dd2370dba8.png"/>

## Introduction

This project is developed to easily visualize label assignment results. The script lets users analyze where, and how many, positive samples each ground-truth box is assigned in the image.

Currently, the script only supports `YOLOv5`.

## Usage

### Command

```shell
python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py
```
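Besides the config path, the script accepts several optional flags (defined in `parse_args` of the script shown below), e.g. the number of images to save, whether to draw priors, and the output folder. A hypothetical invocation combining them:

```shell
python projects/assigner_visualization/assigner_visualization.py \
    projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py \
    --show-number 10 \
    --show-prior \
    --seed 2022 \
    --output-dir assigned_results
```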
@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import sys

import mmcv
import numpy as np
import torch
from mmengine import ProgressBar
from mmengine.config import Config, DictAction
from mmengine.dataset import COLLATE_FUNCTIONS
from numpy import random

from mmyolo.registry import DATASETS, MODELS
from mmyolo.utils import register_all_modules
from projects.assigner_visualization.dense_heads import YOLOv5HeadAssigner
from projects.assigner_visualization.visualization import \
    YOLOAssignerVisualizer


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMYOLO show the positive sample assigning'
        ' results.')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--show-number',
        '-n',
        type=int,
        default=sys.maxsize,
        help='number of images selected to save; '
        'must be bigger than 0. If the number is bigger than the length '
        'of the dataset, show all the images in the dataset; '
        'defaults to "sys.maxsize", i.e. show all images in the dataset')
    parser.add_argument(
        '--output-dir',
        default='assigned_results',
        type=str,
        help='The name of the folder where the images are saved.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    parser.add_argument(
        '--show-prior',
        default=False,
        action='store_true',
        help='Whether to show priors on the image.')
    parser.add_argument(
        '--not-show-label',
        default=False,
        action='store_true',
        help='Whether not to show labels on the image.')
    parser.add_argument('--seed', default=-1, type=int, help='random seed')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    register_all_modules()

    # set random seed
    seed = int(args.seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        random.seed(seed)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # build model
    model = MODELS.build(cfg.model)
    assert isinstance(model.bbox_head, YOLOv5HeadAssigner), \
        'Now, this script only supports YOLOv5, and bbox_head must use ' \
        '`YOLOv5HeadAssigner`. Please use `' \
        'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \
        '` as the config file.'
    model.eval()
    model.to(args.device)

    # build dataset
    dataset_cfg = cfg.get('train_dataloader').get('dataset')
    dataset = DATASETS.build(dataset_cfg)

    # get collate_fn
    collate_fn_cfg = cfg.get('train_dataloader').pop(
        'collate_fn', dict(type='pseudo_collate'))
    collate_fn_type = collate_fn_cfg.pop('type')
    collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type)

    # init visualizer
    visualizer = YOLOAssignerVisualizer(
        vis_backends=[{
            'type': 'LocalVisBackend'
        }], name='visualizer')
    visualizer.dataset_meta = dataset.metainfo
    # need priors size to draw priors
    visualizer.priors_size = model.bbox_head.prior_generator.base_anchors

    # make output dir
    os.makedirs(args.output_dir, exist_ok=True)

    # init visualization image number
    assert args.show_number > 0
    display_number = min(args.show_number, len(dataset))

    progress_bar = ProgressBar(display_number)
    for ind_img in range(display_number):
        data = dataset.prepare_data(ind_img)

        # convert data to batch format
        batch_data = collate_fn([data])
        with torch.no_grad():
            assign_results = model.assign(batch_data)

        img = data['inputs'].cpu().numpy().astype(np.uint8).transpose(
            (1, 2, 0))
        # bgr2rgb
        img = mmcv.bgr2rgb(img)

        gt_instances = data['data_samples'].gt_instances

        img_show = visualizer.draw_assign(img, assign_results, gt_instances,
                                          args.show_prior,
                                          args.not_show_label)

        if hasattr(data['data_samples'], 'img_path'):
            filename = osp.basename(data['data_samples'].img_path)
        else:
            # some datasets do not provide an image path
            filename = f'{ind_img}.jpg'
        out_file = osp.join(args.output_dir, filename)

        # convert rgb to bgr and save the image
        mmcv.imwrite(mmcv.rgb2bgr(img_show), out_file)
        progress_bar.update()


if __name__ == '__main__':
    main()
@ -0,0 +1,11 @@
_base_ = [
    '../../../configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
]

custom_imports = dict(imports=[
    'projects.assigner_visualization.detectors',
    'projects.assigner_visualization.dense_heads'
])

model = dict(
    type='YOLODetectorAssigner', bbox_head=dict(type='YOLOv5HeadAssigner'))
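The config inherits the stock YOLOv5-s config and only swaps in the assigner-aware detector and head; `custom_imports` ensures the two project modules are registered when the config is loaded. A minimal sketch of building the model from it, mirroring the calls in `main()` above:

```python
from mmengine.config import Config

from mmyolo.registry import MODELS
from mmyolo.utils import register_all_modules

register_all_modules()  # register the stock mmyolo components

# Loading the config also imports `custom_imports`, which registers
# YOLODetectorAssigner and YOLOv5HeadAssigner (MMEngine's default behavior).
cfg = Config.fromfile(
    'projects/assigner_visualization/configs/'
    'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py')
model = MODELS.build(cfg.model)
assert type(model).__name__ == 'YOLODetectorAssigner'
```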
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .yolov5_head_assigner import YOLOv5HeadAssigner

__all__ = ['YOLOv5HeadAssigner']
@ -0,0 +1,188 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import torch
from mmdet.models.utils import unpack_gt_instances
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.models import YOLOv5Head
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLOv5HeadAssigner(YOLOv5Head):

    def assign_by_gt_and_feat(
        self,
        batch_gt_instances: Sequence[InstanceData],
        batch_img_metas: Sequence[dict],
        inputs_hw: Union[Tensor, tuple] = (640, 640)
    ) -> list:
        """Calculate the assigning results based on the gt and features
        extracted by the detection head.

        Args:
            batch_gt_instances (Sequence[InstanceData]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
        Returns:
            list: A list of assigning results, with one entry per feature
            map level and one dict per prior in each entry.
        """
        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        device = batch_targets_normed.device
        scaled_factor = torch.ones(7, device=device)
        gt_inds = torch.arange(
            batch_targets_normed.shape[1],
            dtype=torch.long,
            device=device,
            requires_grad=False).unsqueeze(0).repeat((self.num_base_priors, 1))

        assign_results = []
        for i in range(self.num_levels):
            assign_results_feat = []
            h = inputs_hw[0] // self.featmap_strides[i]
            w = inputs_hw[1] // self.featmap_strides[i]

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                for k in range(self.num_base_priors):
                    assign_results_feat.append({
                        'stride':
                        self.featmap_strides[i],
                        'grid_x_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'grid_y_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'img_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'class_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'retained_gt_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'prior_ind':
                        k
                    })
                assign_results.append(assign_results_feat)
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor([w, h, w, h])
            # Scale batch_targets from range 0-1 to range 0-features_maps size.
            # (num_base_priors, num_bboxes, 7)
            batch_targets_scaled = batch_targets_normed * scaled_factor

            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]
            match_gt_inds = gt_inds[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                for k in range(self.num_base_priors):
                    assign_results_feat.append({
                        'stride':
                        self.featmap_strides[i],
                        'grid_x_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'grid_y_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'img_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'class_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'retained_gt_inds':
                        torch.zeros([0], dtype=torch.int64).to(device),
                        'prior_ind':
                        k
                    })
                assign_results.append(assign_results_feat)
                continue

            # 3. Positive samples with additional neighbors

            # check the left, up, right, bottom sides of the
            # targets grid, and determine whether to assign
            # them as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_gt_inds = match_gt_inds.repeat((5, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
                                                       1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds, (img_inds, class_inds) = priors_inds.long().view(
                -1), img_class_inds.long().T

            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            for k in range(self.num_base_priors):
                retained_inds = priors_inds == k
                assign_results_prior = {
                    'stride': self.featmap_strides[i],
                    'grid_x_inds': grid_x_inds[retained_inds],
                    'grid_y_inds': grid_y_inds[retained_inds],
                    'img_inds': img_inds[retained_inds],
                    'class_inds': class_inds[retained_inds],
                    'retained_gt_inds': retained_gt_inds[retained_inds],
                    'prior_ind': k
                }
                assign_results_feat.append(assign_results_prior)
            assign_results.append(assign_results_feat)
        return assign_results

    def assign(self, batch_data_samples: Union[list, dict],
               inputs_hw: Union[tuple, torch.Size]) -> list:
        """Calculate assigning results. This function is provided to the
        `assigner_visualization.py` script.

        Args:
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            inputs_hw: Height and width of inputs size.

        Returns:
            list: The assigning results.
        """
        if isinstance(batch_data_samples, list):
            outputs = unpack_gt_instances(batch_data_samples)
            (batch_gt_instances, batch_gt_instances_ignore,
             batch_img_metas) = outputs

            # `batch_gt_instances_ignore` is not used by the assigner
            assign_inputs = (batch_gt_instances, batch_img_metas, inputs_hw)
        else:
            # Fast version
            assign_inputs = (batch_data_samples['bboxes_labels'],
                             batch_data_samples['img_metas'], inputs_hw)
        assign_results = self.assign_by_gt_and_feat(*assign_inputs)

        return assign_results
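To make step 2 concrete, here is a small self-contained sketch of the shape-match rule for a single prior (the numbers are hypothetical; `prior_match_thr` defaults to 4.0 in the YOLOv5 configs): a gt box matches a prior only when the larger of the box/prior and prior/box ratios, over both width and height, stays below the threshold.

```python
import torch

# Hypothetical values for illustration only.
prior_match_thr = 4.0  # default in the YOLOv5 configs
# wh of three gt boxes, already scaled to feature-map units
gt_wh = torch.tensor([[3.0, 2.0], [40.0, 2.0], [4.0, 5.0]])
# one prior's base size at this level, in the same units
prior_wh = torch.tensor([4.0, 4.0])

# Same criterion as step 2 above, specialized to one prior:
# a gt matches only if max(gt/prior, prior/gt) over w and h
# stays below the threshold.
wh_ratio = gt_wh / prior_wh
match = torch.max(wh_ratio, 1 / wh_ratio).max(dim=1)[0] < prior_match_thr
print(match)  # tensor([ True, False,  True]); the 40x2 box is too elongated
```

Step 3 then promotes up to two neighboring grid cells per matched gt to positives as well: with `near_neighbor_thr` at its default of 0.5, a center lying in the left (or top) half of its cell also claims the cell to the left (or above), and symmetrically for right/bottom. This is why a single gt typically produces three colored dots per matched prior in the visualization.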
@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from projects.assigner_visualization.detectors.yolo_detector_assigner import \
    YOLODetectorAssigner

__all__ = ['YOLODetectorAssigner']
@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

from mmyolo.models import YOLODetector
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLODetectorAssigner(YOLODetector):

    def assign(self, data: dict) -> Union[dict, list]:
        """Calculate assigning results from a batch of inputs and data
        samples. This function is provided to the
        `assigner_visualization.py` script.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.

        Returns:
            dict or list: The assigning results.
        """
        assert isinstance(data, dict)
        assert len(data['inputs']) == 1, 'Only support batch size == 1'
        data = self.data_preprocessor(data, True)
        inputs_hw = data['inputs'].shape[-2:]
        assign_results = self.bbox_head.assign(data['data_samples'], inputs_hw)
        return assign_results
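A short sketch of how the nested result is consumed (mirroring the loop in `assigner_visualization.py`; `model`, a `YOLODetectorAssigner`, and the single-image `batch_data` are assumed to be prepared as in the main script above):

```python
import torch

# `model` and `batch_data` are assumed to come from the main script above.
with torch.no_grad():
    assign_results = model.assign(batch_data)

# The structure is a list over feature map levels; each level holds one
# dict per base prior with the keys built in `assign_by_gt_and_feat`.
for feat_ind, assign_results_feat in enumerate(assign_results):
    for result in assign_results_feat:
        print(f"level {feat_ind} (stride {result['stride']}), "
              f"prior {result['prior_ind']}: "
              f"{len(result['retained_gt_inds'])} positive samples")
```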
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigner_visualizer import YOLOAssignerVisualizer

__all__ = ['YOLOAssignerVisualizer']
@ -0,0 +1,314 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Union

import mmcv
import numpy as np
import torch
from mmdet.structures.bbox import HorizontalBoxes
from mmdet.visualization import DetLocalVisualizer
from mmdet.visualization.palette import _get_adaptive_scales, get_palette
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import VISUALIZERS


@VISUALIZERS.register_module()
class YOLOAssignerVisualizer(DetLocalVisualizer):
    """MMYOLO Detection Assigner Visualizer.

    This class is provided to the `assigner_visualization.py` script.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
    """

    def __init__(self, name: str = 'visualizer', *args, **kwargs):
        super().__init__(name=name, *args, **kwargs)
        # needs priors_size from the config
        self.priors_size = None

    def draw_grid(self,
                  stride: int = 8,
                  line_styles: Union[str, List[str]] = ':',
                  colors: Union[str, tuple, List[str],
                                List[tuple]] = (180, 180, 180),
                  line_widths: Union[Union[int, float],
                                     List[Union[int, float]]] = 1):
        """Draw grids on the image.

        Args:
            stride (int): Downsample factor of the feature map.
            line_styles (Union[str, List[str]]): The linestyle of lines.
                ``line_styles`` can have the same length as the lines or
                just a single value. If ``line_styles`` is a single value,
                all the lines will have the same linestyle. Reference to
                https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
                for more details. Defaults to ':'.
            colors (Union[str, tuple, List[str], List[tuple]]): The colors of
                lines. ``colors`` can have the same length as the lines or
                just a single value. If ``colors`` is a single value, all the
                lines will have the same colors. Reference to
                https://matplotlib.org/stable/gallery/color/named_colors.html
                for more details. Defaults to (180, 180, 180).
            line_widths (Union[Union[int, float], List[Union[int, float]]]):
                The linewidth of lines. ``line_widths`` can have the same
                length as the lines or just a single value. If
                ``line_widths`` is a single value, all the lines will have
                the same linewidth. Defaults to 1.
        """
        assert self._image is not None, 'Please set image using `set_image`'
        # draw vertical lines
        x_datas_vertical = ((np.arange(self.width // stride - 1) + 1) *
                            stride).reshape((-1, 1)).repeat(
                                2, axis=1)
        y_datas_vertical = np.array([[0, self.height - 1]]).repeat(
            self.width // stride - 1, axis=0)
        self.draw_lines(
            x_datas_vertical,
            y_datas_vertical,
            colors=colors,
            line_styles=line_styles,
            line_widths=line_widths)

        # draw horizontal lines
        x_datas_horizontal = np.array([[0, self.width - 1]]).repeat(
            self.height // stride - 1, axis=0)
        y_datas_horizontal = ((np.arange(self.height // stride - 1) + 1) *
                              stride).reshape((-1, 1)).repeat(
                                  2, axis=1)
        self.draw_lines(
            x_datas_horizontal,
            y_datas_horizontal,
            colors=colors,
            line_styles=line_styles,
            line_widths=line_widths)

    def draw_instances_assign(self,
                              instances: InstanceData,
                              retained_gt_inds: Tensor,
                              not_show_label: bool = False):
        """Draw GT instances.

        Args:
            instances (:obj:`InstanceData`): gt_instance. It usually
                includes ``bboxes`` and ``labels`` attributes.
            retained_gt_inds (Tensor): The gt indexes assigned as
                positive samples in the current prior.
            not_show_label (bool): Whether not to show gt labels on images.
        """
        assert self.dataset_meta is not None
        classes = self.dataset_meta['CLASSES']
        palette = self.dataset_meta['PALETTE']
        if len(retained_gt_inds) == 0:
            return self.get_image()
        draw_gt_inds = torch.from_numpy(
            np.array(
                list(set(retained_gt_inds.cpu().numpy())), dtype=np.int64))
        bboxes = instances.bboxes[draw_gt_inds]
        labels = instances.labels[draw_gt_inds]

        if not isinstance(bboxes, Tensor):
            bboxes = bboxes.tensor

        edge_colors = [palette[i] for i in labels]

        max_label = int(max(labels) if len(labels) > 0 else 0)
        text_palette = get_palette(self.text_color, max_label + 1)
        text_colors = [text_palette[label] for label in labels]

        self.draw_bboxes(
            bboxes,
            edge_colors=edge_colors,
            alpha=self.alpha,
            line_widths=self.line_width)

        if not not_show_label:
            positions = bboxes[:, :2] + self.line_width
            areas = (bboxes[:, 3] - bboxes[:, 1]) * (
                bboxes[:, 2] - bboxes[:, 0])
            scales = _get_adaptive_scales(areas)
            for i, (pos, label) in enumerate(zip(positions, labels)):
                label_text = classes[
                    label] if classes is not None else f'class {label}'

                self.draw_texts(
                    label_text,
                    pos,
                    colors=text_colors[i],
                    font_sizes=int(13 * scales[i]),
                    bboxes=[{
                        'facecolor': 'black',
                        'alpha': 0.8,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    }])

    def draw_positive_assign(self,
                             grid_x_inds: Tensor,
                             grid_y_inds: Tensor,
                             class_inds: Tensor,
                             stride: int,
                             bboxes: Union[Tensor, HorizontalBoxes],
                             retained_gt_inds: Tensor,
                             offset: float = 0.5):
        """Draw positive sample points on the image.

        Args:
            grid_x_inds (Tensor): The X-axis indexes of the positive samples
                in the current prior.
            grid_y_inds (Tensor): The Y-axis indexes of the positive samples
                in the current prior.
            class_inds (Tensor): The class indexes of the positive samples
                in the current prior.
            stride (int): Downsample factor of the feature map.
            bboxes (Union[Tensor, HorizontalBoxes]): Bounding boxes of GT.
            retained_gt_inds (Tensor): The gt indexes assigned as
                positive samples in the current prior.
            offset (float): The offset of points, the value is normalized
                with the corresponding stride. Defaults to 0.5.
        """
        if not isinstance(bboxes, Tensor):
            # Convert HorizontalBoxes to Tensor
            bboxes = bboxes.tensor

        # The PALETTE in the dataset_meta is required
        assert self.dataset_meta is not None
        palette = self.dataset_meta['PALETTE']
        x = ((grid_x_inds + offset) * stride).long()
        y = ((grid_y_inds + offset) * stride).long()
        center = torch.stack((x, y), dim=-1)

        retained_bboxes = bboxes[retained_gt_inds]
        bbox_wh = retained_bboxes[:, 2:] - retained_bboxes[:, :2]
        bbox_area = bbox_wh[:, 0] * bbox_wh[:, 1]
        radius = _get_adaptive_scales(bbox_area) * 4
        colors = [palette[i] for i in class_inds]

        self.draw_circles(
            center,
            radius,
            colors,
            line_widths=0,
            face_colors=colors,
            alpha=1.0)

    def draw_prior(self,
                   grid_x_inds: Tensor,
                   grid_y_inds: Tensor,
                   class_inds: Tensor,
                   stride: int,
                   feat_ind: int,
                   prior_ind: int,
                   offset: float = 0.5):
        """Draw priors on the image.

        Args:
            grid_x_inds (Tensor): The X-axis indexes of the positive samples
                in the current prior.
            grid_y_inds (Tensor): The Y-axis indexes of the positive samples
                in the current prior.
            class_inds (Tensor): The class indexes of the positive samples
                in the current prior.
            stride (int): Downsample factor of the feature map.
            feat_ind (int): Index of the featmap.
            prior_ind (int): Index of the prior in the current featmap.
            offset (float): The offset of points, the value is normalized
                with the corresponding stride. Defaults to 0.5.
        """

        palette = self.dataset_meta['PALETTE']
        center_x = ((grid_x_inds + offset) * stride)
        center_y = ((grid_y_inds + offset) * stride)
        xyxy = torch.stack((center_x, center_y, center_x, center_y), dim=1)
        assert self.priors_size is not None
        xyxy += self.priors_size[feat_ind][prior_ind]

        colors = [palette[i] for i in class_inds]
        self.draw_bboxes(
            xyxy,
            edge_colors=colors,
            alpha=self.alpha,
            line_styles='--',
            line_widths=math.ceil(self.line_width * 0.3))

    def draw_assign(self,
                    image: np.ndarray,
                    assign_results: List[List[dict]],
                    gt_instances: InstanceData,
                    show_prior: bool = False,
                    not_show_label: bool = False) -> np.ndarray:
        """Draw assigning results.

        Args:
            image (np.ndarray): The image to draw.
            assign_results (list): The assigning results.
            gt_instances (:obj:`InstanceData`): Data structure for
                instance-level annotations or predictions.
            show_prior (bool): Whether to show priors on the image.
            not_show_label (bool): Whether not to show gt labels on images.

        Returns:
            np.ndarray: The drawn image, in RGB channel order.
        """
        img_show_list = []
        for feat_ind, assign_results_feat in enumerate(assign_results):
            img_show_list_feat = []
            for prior_ind, assign_results_prior in enumerate(
                    assign_results_feat):
                self.set_image(image)
                h, w = image.shape[:2]

                # draw grid
                stride = assign_results_prior['stride']
                self.draw_grid(stride)

                # draw prior on matched gt
                grid_x_inds = assign_results_prior['grid_x_inds']
                grid_y_inds = assign_results_prior['grid_y_inds']
                class_inds = assign_results_prior['class_inds']
                prior_ind = assign_results_prior['prior_ind']
                if show_prior:
                    self.draw_prior(grid_x_inds, grid_y_inds, class_inds,
                                    stride, feat_ind, prior_ind)

                # draw matched gt
                retained_gt_inds = assign_results_prior['retained_gt_inds']
                self.draw_instances_assign(gt_instances, retained_gt_inds,
                                           not_show_label)

                # draw positive samples
                self.draw_positive_assign(grid_x_inds, grid_y_inds,
                                          class_inds, stride,
                                          gt_instances.bboxes,
                                          retained_gt_inds)

                # draw title
                base_prior = self.priors_size[feat_ind][prior_ind]
                prior_size = (base_prior[2] - base_prior[0],
                              base_prior[3] - base_prior[1])
                pos = np.array((20, 20))
                text = f'feat_ind: {feat_ind} ' \
                       f'prior_ind: {prior_ind} ' \
                       f'prior_size: ({prior_size[0]}, {prior_size[1]})'
                scales = _get_adaptive_scales(np.array([h * w / 16]))
                font_sizes = int(13 * scales)
                self.draw_texts(
                    text,
                    pos,
                    colors=self.text_color,
                    font_sizes=font_sizes,
                    bboxes=[{
                        'facecolor': 'black',
                        'alpha': 0.8,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    }])

                img_show = self.get_image()
                img_show = mmcv.impad(img_show, padding=(5, 5, 5, 5))
                img_show_list_feat.append(img_show)
            img_show_list.append(np.concatenate(img_show_list_feat, axis=1))

        # Merge all images into one image
        return np.concatenate(img_show_list, axis=0)
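Finally, a minimal sketch of driving the visualizer directly, with the same initialization as in the script's `main()`; `model`, `dataset`, `img`, `assign_results` and `gt_instances` are assumed to have been prepared as shown there:

```python
import mmcv

from projects.assigner_visualization.visualization import \
    YOLOAssignerVisualizer

visualizer = YOLOAssignerVisualizer(
    vis_backends=[{'type': 'LocalVisBackend'}], name='visualizer')
visualizer.dataset_meta = dataset.metainfo
# base anchors are required so draw_prior() can size the dashed boxes
visualizer.priors_size = model.bbox_head.prior_generator.base_anchors

# `img` must be RGB; draw_assign returns one mosaic with a row per
# feature map level and a column per prior at that level.
img_show = visualizer.draw_assign(
    img, assign_results, gt_instances, show_prior=True)
mmcv.imwrite(mmcv.rgb2bgr(img_show), 'assigned_results/example.jpg')
```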