diff --git a/projects/assigner_visualization/README.md b/projects/assigner_visualization/README.md new file mode 100644 index 00000000..579ab5c8 --- /dev/null +++ b/projects/assigner_visualization/README.md @@ -0,0 +1,17 @@ +# MMYOLO Model Assigner Visualization + + + +## Introduction + +This project is developed for easily showing assigning results. The script allows users to analyze where and how many positive samples each gt is assigned in the image. + +Now, the script only support `YOLOv5` . + +## Usage + +### Command + +```shell +python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py ` +``` diff --git a/projects/assigner_visualization/assigner_visualization.py b/projects/assigner_visualization/assigner_visualization.py new file mode 100644 index 00000000..df489c43 --- /dev/null +++ b/projects/assigner_visualization/assigner_visualization.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import sys + +import mmcv +import numpy as np +import torch +from mmengine import ProgressBar +from mmengine.config import Config, DictAction +from mmengine.dataset import COLLATE_FUNCTIONS +from numpy import random + +from mmyolo.registry import DATASETS, MODELS +from mmyolo.utils import register_all_modules +from projects.assigner_visualization.dense_heads import YOLOv5HeadAssigner +from projects.assigner_visualization.visualization import \ + YOLOAssignerVisualizer + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMYOLO show the positive sample assigning' + ' results.') + parser.add_argument('config', help='config file path') + parser.add_argument( + '--show-number', + '-n', + type=int, + default=sys.maxsize, + help='number of images selected to save, ' + 'must bigger than 0. if the number is bigger than length ' + 'of dataset, show all the images in dataset; ' + 'default "sys.maxsize", show all images in dataset') + parser.add_argument( + '--output-dir', + default='assigned_results', + type=str, + help='The name of the folder where the image is saved.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference.') + parser.add_argument( + '--show-prior', + default=False, + action='store_true', + help='Whether to show prior on image.') + parser.add_argument( + '--not-show-label', + default=False, + action='store_true', + help='Whether to show label on image.') + parser.add_argument('--seed', default=-1, type=int, help='random seed') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + register_all_modules() + + # set random seed + seed = int(args.seed) + if seed != -1: + print(f'Set the global seed: {seed}') + random.seed(int(args.seed)) + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # build model + model = MODELS.build(cfg.model) + assert isinstance(model.bbox_head, YOLOv5HeadAssigner),\ + 'Now, this script only support yolov5, and bbox_head must use ' \ + '`YOLOv5HeadAssigner`. Please use `' \ + 'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \ + '` as config file.' + model.eval() + model.to(args.device) + + # build dataset + dataset_cfg = cfg.get('train_dataloader').get('dataset') + dataset = DATASETS.build(dataset_cfg) + + # get collate_fn + collate_fn_cfg = cfg.get('train_dataloader').pop( + 'collate_fn', dict(type='pseudo_collate')) + collate_fn_type = collate_fn_cfg.pop('type') + collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type) + + # init visualizer + visualizer = YOLOAssignerVisualizer( + vis_backends=[{ + 'type': 'LocalVisBackend' + }], name='visualizer') + visualizer.dataset_meta = dataset.metainfo + # need priors size to draw priors + visualizer.priors_size = model.bbox_head.prior_generator.base_anchors + + # make output dir + os.makedirs(args.output_dir, exist_ok=True) + + # init visualization image number + assert args.show_number > 0 + display_number = min(args.show_number, len(dataset)) + + progress_bar = ProgressBar(display_number) + for ind_img in range(display_number): + data = dataset.prepare_data(ind_img) + + # convert data to batch format + batch_data = collate_fn([data]) + with torch.no_grad(): + assign_results = model.assign(batch_data) + + img = data['inputs'].cpu().numpy().astype(np.uint8).transpose( + (1, 2, 0)) + # bgr2rgb + img = mmcv.bgr2rgb(img) + + gt_instances = data['data_samples'].gt_instances + + img_show = visualizer.draw_assign(img, assign_results, gt_instances, + args.show_prior, args.not_show_label) + + if hasattr(data['data_samples'], 'img_path'): + filename = osp.basename(data['data_samples'].img_path) + else: + # some dataset have not image path + filename = f'{ind_img}.jpg' + out_file = osp.join(args.output_dir, filename) + + # convert rgb 2 bgr and save img + mmcv.imwrite(mmcv.rgb2bgr(img_show), out_file) + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py b/projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py new file mode 100644 index 00000000..1db799b5 --- /dev/null +++ b/projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py @@ -0,0 +1,11 @@ +_base_ = [ + '../../../configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py' +] + +custom_imports = dict(imports=[ + 'projects.assigner_visualization.detectors', + 'projects.assigner_visualization.dense_heads' +]) + +model = dict( + type='YOLODetectorAssigner', bbox_head=dict(type='YOLOv5HeadAssigner')) diff --git a/projects/assigner_visualization/dense_heads/__init__.py b/projects/assigner_visualization/dense_heads/__init__.py new file mode 100644 index 00000000..c8e368d9 --- /dev/null +++ b/projects/assigner_visualization/dense_heads/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .yolov5_head_assigner import YOLOv5HeadAssigner + +__all__ = ['YOLOv5HeadAssigner'] diff --git a/projects/assigner_visualization/dense_heads/yolov5_head_assigner.py b/projects/assigner_visualization/dense_heads/yolov5_head_assigner.py new file mode 100644 index 00000000..599963fe --- /dev/null +++ b/projects/assigner_visualization/dense_heads/yolov5_head_assigner.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Union + +import torch +from mmdet.models.utils import unpack_gt_instances +from mmengine.structures import InstanceData +from torch import Tensor + +from mmyolo.models import YOLOv5Head +from mmyolo.registry import MODELS + + +@MODELS.register_module() +class YOLOv5HeadAssigner(YOLOv5Head): + + def assign_by_gt_and_feat( + self, + batch_gt_instances: Sequence[InstanceData], + batch_img_metas: Sequence[dict], + inputs_hw: Union[Tensor, tuple] = (640, 640) + ) -> dict: + """Calculate the assigning results based on the gt and features + extracted by the detection head. + + Args: + batch_gt_instances (Sequence[InstanceData]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (Sequence[dict]): Meta information of each image, + e.g., image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + inputs_hw (Union[Tensor, tuple]): Height and width of inputs size. + Returns: + dict[str, Tensor]: A dictionary of assigning results. + """ + # 1. Convert gt to norm format + batch_targets_normed = self._convert_gt_to_norm_format( + batch_gt_instances, batch_img_metas) + + device = batch_targets_normed.device + scaled_factor = torch.ones(7, device=device) + gt_inds = torch.arange( + batch_targets_normed.shape[1], + dtype=torch.long, + device=device, + requires_grad=False).unsqueeze(0).repeat((self.num_base_priors, 1)) + + assign_results = [] + for i in range(self.num_levels): + assign_results_feat = [] + h = inputs_hw[0] // self.featmap_strides[i] + w = inputs_hw[1] // self.featmap_strides[i] + + # empty gt bboxes + if batch_targets_normed.shape[1] == 0: + for k in range(self.num_base_priors): + assign_results_feat.append({ + 'stride': + self.featmap_strides[i], + 'grid_x_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'grid_y_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'img_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'class_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'retained_gt_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'prior_ind': + k + }) + assign_results.append(assign_results_feat) + continue + + priors_base_sizes_i = self.priors_base_sizes[i] + # feature map scale whwh + scaled_factor[2:6] = torch.tensor([w, h, w, h]) + # Scale batch_targets from range 0-1 to range 0-features_maps size. + # (num_base_priors, num_bboxes, 7) + batch_targets_scaled = batch_targets_normed * scaled_factor + + # 2. Shape match + wh_ratio = batch_targets_scaled[..., + 4:6] / priors_base_sizes_i[:, None] + match_inds = torch.max( + wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr + batch_targets_scaled = batch_targets_scaled[match_inds] + match_gt_inds = gt_inds[match_inds] + + # no gt bbox matches anchor + if batch_targets_scaled.shape[0] == 0: + for k in range(self.num_base_priors): + assign_results_feat.append({ + 'stride': + self.featmap_strides[i], + 'grid_x_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'grid_y_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'img_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'class_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'retained_gt_inds': + torch.zeros([0], dtype=torch.int64).to(device), + 'prior_ind': + k + }) + assign_results.append(assign_results_feat) + continue + + # 3. Positive samples with additional neighbors + + # check the left, up, right, bottom sides of the + # targets grid, and determine whether assigned + # them as positive samples as well. + batch_targets_cxcy = batch_targets_scaled[:, 2:4] + grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy + left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) & + (batch_targets_cxcy > 1)).T + right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) & + (grid_xy > 1)).T + offset_inds = torch.stack( + (torch.ones_like(left), left, up, right, bottom)) + + batch_targets_scaled = batch_targets_scaled.repeat( + (5, 1, 1))[offset_inds] + retained_gt_inds = match_gt_inds.repeat((5, 1))[offset_inds] + retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1], + 1)[offset_inds] + + # prepare pred results and positive sample indexes to + # calculate class loss and bbox lo + _chunk_targets = batch_targets_scaled.chunk(4, 1) + img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets + priors_inds, (img_inds, class_inds) = priors_inds.long().view( + -1), img_class_inds.long().T + + grid_xy_long = (grid_xy - + retained_offsets * self.near_neighbor_thr).long() + grid_x_inds, grid_y_inds = grid_xy_long.T + for k in range(self.num_base_priors): + retained_inds = priors_inds == k + assign_results_prior = { + 'stride': self.featmap_strides[i], + 'grid_x_inds': grid_x_inds[retained_inds], + 'grid_y_inds': grid_y_inds[retained_inds], + 'img_inds': img_inds[retained_inds], + 'class_inds': class_inds[retained_inds], + 'retained_gt_inds': retained_gt_inds[retained_inds], + 'prior_ind': k + } + assign_results_feat.append(assign_results_prior) + assign_results.append(assign_results_feat) + return assign_results + + def assign(self, batch_data_samples: Union[list, dict], + inputs_hw: Union[tuple, torch.Size]) -> dict: + """Calculate assigning results. This function is provided to the + `assigner_visualization.py` script. + + Args: + batch_data_samples (List[:obj:`DetDataSample`], dict): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + inputs_hw: Height and width of inputs size + + Returns: + dict: A dictionary of assigning components. + """ + if isinstance(batch_data_samples, list): + outputs = unpack_gt_instances(batch_data_samples) + (batch_gt_instances, batch_gt_instances_ignore, + batch_img_metas) = outputs + + assign_inputs = (batch_gt_instances, batch_img_metas, + batch_gt_instances_ignore, inputs_hw) + else: + # Fast version + assign_inputs = (batch_data_samples['bboxes_labels'], + batch_data_samples['img_metas'], inputs_hw) + assign_results = self.assign_by_gt_and_feat(*assign_inputs) + + return assign_results diff --git a/projects/assigner_visualization/detectors/__init__.py b/projects/assigner_visualization/detectors/__init__.py new file mode 100644 index 00000000..155606a0 --- /dev/null +++ b/projects/assigner_visualization/detectors/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from projects.assigner_visualization.detectors.yolo_detector_assigner import \ + YOLODetectorAssigner + +__all__ = ['YOLODetectorAssigner'] diff --git a/projects/assigner_visualization/detectors/yolo_detector_assigner.py b/projects/assigner_visualization/detectors/yolo_detector_assigner.py new file mode 100644 index 00000000..394f8a06 --- /dev/null +++ b/projects/assigner_visualization/detectors/yolo_detector_assigner.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +from mmyolo.models import YOLODetector +from mmyolo.registry import MODELS + + +@MODELS.register_module() +class YOLODetectorAssigner(YOLODetector): + + def assign(self, data: dict) -> Union[dict, list]: + """Calculate assigning results from a batch of inputs and data + samples.This function is provided to the `assigner_visualization.py` + script. + + Args: + data (dict or tuple or list): Data sampled from dataset. + + Returns: + dict: A dictionary of assigning components. + """ + assert isinstance(data, dict) + assert len(data['inputs']) == 1, 'Only support batchsize == 1' + data = self.data_preprocessor(data, True) + inputs_hw = data['inputs'].shape[-2:] + assign_results = self.bbox_head.assign(data['data_samples'], inputs_hw) + return assign_results diff --git a/projects/assigner_visualization/visualization/__init__.py b/projects/assigner_visualization/visualization/__init__.py new file mode 100644 index 00000000..521a25b8 --- /dev/null +++ b/projects/assigner_visualization/visualization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .assigner_visualizer import YOLOAssignerVisualizer + +__all__ = ['YOLOAssignerVisualizer'] diff --git a/projects/assigner_visualization/visualization/assigner_visualizer.py b/projects/assigner_visualization/visualization/assigner_visualizer.py new file mode 100644 index 00000000..299ee22b --- /dev/null +++ b/projects/assigner_visualization/visualization/assigner_visualizer.py @@ -0,0 +1,314 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Union + +import mmcv +import numpy as np +import torch +from mmdet.structures.bbox import HorizontalBoxes +from mmdet.visualization import DetLocalVisualizer +from mmdet.visualization.palette import _get_adaptive_scales, get_palette +from mmengine.structures import InstanceData +from torch import Tensor + +from mmyolo.registry import VISUALIZERS + + +@VISUALIZERS.register_module() +class YOLOAssignerVisualizer(DetLocalVisualizer): + """MMYOLO Detection Assigner Visualizer. + + This class is provided to the `assigner_visualization.py` script. + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + """ + + def __init__(self, name: str = 'visualizer', *args, **kwargs): + super().__init__(name=name, *args, **kwargs) + # need priors_size from config + self.priors_size = None + + def draw_grid(self, + stride: int = 8, + line_styles: Union[str, List[str]] = ':', + colors: Union[str, tuple, List[str], + List[tuple]] = (180, 180, 180), + line_widths: Union[Union[int, float], + List[Union[int, float]]] = 1): + """Draw grids on image. + + Args: + stride (int): Downsample factor of feature map. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single + value, all the lines will have the same linestyle. + Reference to + https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle + for more details. Defaults to ':'. + colors (Union[str, tuple, List[str], List[tuple]]): The colors of + lines. ``colors`` can have the same length with lines or just + single value. If ``colors`` is single value, all the lines + will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to (180, 180, 180). + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have + the same length with lines or just single value. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 1. + """ + assert self._image is not None, 'Please set image using `set_image`' + # draw vertical lines + x_datas_vertical = ((np.arange(self.width // stride - 1) + 1) * + stride).reshape((-1, 1)).repeat( + 2, axis=1) + y_datas_vertical = np.array([[0, self.height - 1]]).repeat( + self.width // stride - 1, axis=0) + self.draw_lines( + x_datas_vertical, + y_datas_vertical, + colors=colors, + line_styles=line_styles, + line_widths=line_widths) + + # draw horizontal lines + x_datas_horizontal = np.array([[0, self.width - 1]]).repeat( + self.height // stride - 1, axis=0) + y_datas_horizontal = ((np.arange(self.height // stride - 1) + 1) * + stride).reshape((-1, 1)).repeat( + 2, axis=1) + self.draw_lines( + x_datas_horizontal, + y_datas_horizontal, + colors=colors, + line_styles=line_styles, + line_widths=line_widths) + + def draw_instances_assign(self, + instances: InstanceData, + retained_gt_inds: Tensor, + not_show_label: bool = False): + """Draw instances of GT. + + Args: + instances (:obj:`InstanceData`): gt_instance. It usually + includes ``bboxes`` and ``labels`` attributes. + retained_gt_inds (Tensor): The gt indexes assigned as the + positive sample in the current prior. + not_show_label (bool): Whether to show gt labels on images. + """ + assert self.dataset_meta is not None + classes = self.dataset_meta['CLASSES'] + palette = self.dataset_meta['PALETTE'] + if len(retained_gt_inds) == 0: + return self.get_image() + draw_gt_inds = torch.from_numpy( + np.array( + list(set(retained_gt_inds.cpu().numpy())), dtype=np.int64)) + bboxes = instances.bboxes[draw_gt_inds] + labels = instances.labels[draw_gt_inds] + + if not isinstance(bboxes, Tensor): + bboxes = bboxes.tensor + + edge_colors = [palette[i] for i in labels] + + max_label = int(max(labels) if len(labels) > 0 else 0) + text_palette = get_palette(self.text_color, max_label + 1) + text_colors = [text_palette[label] for label in labels] + + self.draw_bboxes( + bboxes, + edge_colors=edge_colors, + alpha=self.alpha, + line_widths=self.line_width) + + if not not_show_label: + positions = bboxes[:, :2] + self.line_width + areas = (bboxes[:, 3] - bboxes[:, 1]) * ( + bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas) + for i, (pos, label) in enumerate(zip(positions, labels)): + label_text = classes[ + label] if classes is not None else f'class {label}' + + self.draw_texts( + label_text, + pos, + colors=text_colors[i], + font_sizes=int(13 * scales[i]), + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + def draw_positive_assign(self, + grid_x_inds: Tensor, + grid_y_inds: Tensor, + class_inds: Tensor, + stride: int, + bboxes: Union[Tensor, HorizontalBoxes], + retained_gt_inds: Tensor, + offset: float = 0.5): + """ + + Args: + grid_x_inds (Tensor): The X-axis indexes of the positive sample + in current prior. + grid_y_inds (Tensor): The Y-axis indexes of the positive sample + in current prior. + class_inds (Tensor): The classes indexes of the positive sample + in current prior. + stride (int): Downsample factor of feature map. + bboxes (Union[Tensor, HorizontalBoxes]): Bounding boxes of GT. + retained_gt_inds (Tensor): The gt indexes assigned as the + positive sample in the current prior. + offset (float): The offset of points, the value is normalized + with corresponding stride. Defaults to 0.5. + """ + if not isinstance(bboxes, Tensor): + # Convert HorizontalBoxes to Tensor + bboxes = bboxes.tensor + + # The PALETTE in the dataset_meta is required + assert self.dataset_meta is not None + palette = self.dataset_meta['PALETTE'] + x = ((grid_x_inds + offset) * stride).long() + y = ((grid_y_inds + offset) * stride).long() + center = torch.stack((x, y), dim=-1) + + retained_bboxes = bboxes[retained_gt_inds] + bbox_wh = retained_bboxes[:, 2:] - retained_bboxes[:, :2] + bbox_area = bbox_wh[:, 0] * bbox_wh[:, 1] + radius = _get_adaptive_scales(bbox_area) * 4 + colors = [palette[i] for i in class_inds] + + self.draw_circles( + center, + radius, + colors, + line_widths=0, + face_colors=colors, + alpha=1.0) + + def draw_prior(self, + grid_x_inds: Tensor, + grid_y_inds: Tensor, + class_inds: Tensor, + stride: int, + feat_ind: int, + prior_ind: int, + offset: float = 0.5): + """Draw priors on image. + + Args: + grid_x_inds (Tensor): The X-axis indexes of the positive sample + in current prior. + grid_y_inds (Tensor): The Y-axis indexes of the positive sample + in current prior. + class_inds (Tensor): The classes indexes of the positive sample + in current prior. + stride (int): Downsample factor of feature map. + feat_ind (int): Index of featmap. + prior_ind (int): Index of prior in current featmap. + offset (float): The offset of points, the value is normalized + with corresponding stride. Defaults to 0.5. + """ + + palette = self.dataset_meta['PALETTE'] + center_x = ((grid_x_inds + offset) * stride) + center_y = ((grid_y_inds + offset) * stride) + xyxy = torch.stack((center_x, center_y, center_x, center_y), dim=1) + assert self.priors_size is not None + xyxy += self.priors_size[feat_ind][prior_ind] + + colors = [palette[i] for i in class_inds] + self.draw_bboxes( + xyxy, + edge_colors=colors, + alpha=self.alpha, + line_styles='--', + line_widths=math.ceil(self.line_width * 0.3)) + + def draw_assign(self, + image: np.ndarray, + assign_results: List[List[dict]], + gt_instances: InstanceData, + show_prior: bool = False, + not_show_label: bool = False) -> np.ndarray: + """Draw assigning results. + + Args: + image (np.ndarray): The image to draw. + assign_results (list): The assigning results. + gt_instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + show_prior (bool): Whether to show prior on image. + not_show_label (bool): Whether to show gt labels on images. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + img_show_list = [] + for feat_ind, assign_results_feat in enumerate(assign_results): + img_show_list_feat = [] + for prior_ind, assign_results_prior in enumerate( + assign_results_feat): + self.set_image(image) + h, w = image.shape[:2] + + # draw grid + stride = assign_results_prior['stride'] + self.draw_grid(stride) + + # draw prior on matched gt + grid_x_inds = assign_results_prior['grid_x_inds'] + grid_y_inds = assign_results_prior['grid_y_inds'] + class_inds = assign_results_prior['class_inds'] + prior_ind = assign_results_prior['prior_ind'] + if show_prior: + self.draw_prior(grid_x_inds, grid_y_inds, class_inds, + stride, feat_ind, prior_ind) + + # draw matched gt + retained_gt_inds = assign_results_prior['retained_gt_inds'] + self.draw_instances_assign(gt_instances, retained_gt_inds, + not_show_label) + + # draw positive + self.draw_positive_assign(grid_x_inds, grid_y_inds, class_inds, + stride, gt_instances.bboxes, + retained_gt_inds) + + # draw title + base_prior = self.priors_size[feat_ind][prior_ind] + prior_size = (base_prior[2] - base_prior[0], + base_prior[3] - base_prior[1]) + pos = np.array((20, 20)) + text = f'feat_ind: {feat_ind} ' \ + f'prior_ind: {prior_ind} ' \ + f'prior_size: ({prior_size[0]}, {prior_size[1]})' + scales = _get_adaptive_scales(np.array([h * w / 16])) + font_sizes = int(13 * scales) + self.draw_texts( + text, + pos, + colors=self.text_color, + font_sizes=font_sizes, + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + img_show = self.get_image() + img_show = mmcv.impad(img_show, padding=(5, 5, 5, 5)) + img_show_list_feat.append(img_show) + img_show_list.append(np.concatenate(img_show_list_feat, axis=1)) + + # Merge all images into one image + return np.concatenate(img_show_list, axis=0)