From dfc17207baa812def8ca13d0e31e11650be2e1f9 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Sun, 9 Oct 2022 12:45:17 +0800 Subject: [PATCH] [Vis] visualizer refine (#1411) * visualizer refine * updata docs --- mmocr/visualization/__init__.py | 5 +- mmocr/visualization/base_visualizer.py | 131 ++- mmocr/visualization/kie_visualizer.py | 201 +--- mmocr/visualization/textdet_visualizer.py | 140 ++- mmocr/visualization/textrecog_visualizer.py | 75 +- .../visualization/textspotting_visualizer.py | 89 +- mmocr/visualization/visualize.py | 890 ------------------ .../test_base_visualizer.py | 55 ++ .../test_visualization/test_kie_visualizer.py | 15 + .../test_textdet_visualizer.py | 4 + .../test_textrecog_visualizer.py | 10 +- .../test_textspotting_visualizer.py | 113 +++ 12 files changed, 487 insertions(+), 1241 deletions(-) delete mode 100644 mmocr/visualization/visualize.py create mode 100644 tests/test_visualization/test_base_visualizer.py create mode 100644 tests/test_visualization/test_textspotting_visualizer.py diff --git a/mmocr/visualization/__init__.py b/mmocr/visualization/__init__.py index 26081885..b070794b 100644 --- a/mmocr/visualization/__init__.py +++ b/mmocr/visualization/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .base_visualizer import BaseLocalVisualizer from .kie_visualizer import KIELocalVisualizer from .textdet_visualizer import TextDetLocalVisualizer from .textrecog_visualizer import TextRecogLocalVisualizer from .textspotting_visualizer import TextSpottingLocalVisualizer __all__ = [ - 'KIELocalVisualizer', 'TextDetLocalVisualizer', 'TextRecogLocalVisualizer', - 'TextSpottingLocalVisualizer' + 'BaseLocalVisualizer', 'KIELocalVisualizer', 'TextDetLocalVisualizer', + 'TextRecogLocalVisualizer', 'TextSpottingLocalVisualizer' ] diff --git a/mmocr/visualization/base_visualizer.py b/mmocr/visualization/base_visualizer.py index ffee8d3c..1501c6cb 100644 --- a/mmocr/visualization/base_visualizer.py +++ b/mmocr/visualization/base_visualizer.py @@ -50,14 +50,13 @@ class BaseLocalVisualizer(Visualizer): (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122), (191, 162, 208)] - @staticmethod - def _draw_labels(visualizer: Visualizer, - image: np.ndarray, - labels: Union[np.ndarray, torch.Tensor], - bboxes: Union[np.ndarray, torch.Tensor], - colors: Union[str, Sequence[str]] = 'k', - font_size: Union[int, float] = 10, - auto_font_size: bool = False) -> np.ndarray: + def get_labels_image(self, + image: np.ndarray, + labels: Union[np.ndarray, torch.Tensor], + bboxes: Union[np.ndarray, torch.Tensor], + colors: Union[str, Sequence[str]] = 'k', + font_size: Union[int, float] = 10, + auto_font_size: bool = False) -> np.ndarray: """Draw labels on image. Args: @@ -75,7 +74,7 @@ class BaseLocalVisualizer(Visualizer): auto_font_size (bool): Whether to automatically adjust font size. Defaults to False. 
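
        Returns:
            np.ndarray: The image with labels drawn.

        Examples:
            A minimal sketch; the image and box below are made-up
            values for illustration only:

            >>> import numpy as np
            >>> from mmocr.visualization import BaseLocalVisualizer
            >>> image = np.full((40, 40, 3), 255, dtype=np.uint8)
            >>> bboxes = np.array([[5, 5, 30, 20]])
            >>> vis = BaseLocalVisualizer()
            >>> out = vis.get_labels_image(image, ['key'], bboxes)
            >>> out.shape
            (40, 40, 3)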
""" - if colors is not None and isinstance(colors, Sequence): + if colors is not None and isinstance(colors, (list, tuple)): size = math.ceil(len(labels) / len(colors)) colors = (colors * size)[:len(labels)] if auto_font_size: @@ -83,68 +82,124 @@ class BaseLocalVisualizer(Visualizer): font_size, (int, float)) font_size = (bboxes[:, 2:] - bboxes[:, :2]).min(-1) * font_size font_size = font_size.tolist() - visualizer.set_image(image) - visualizer.draw_texts( + self.set_image(image) + self.draw_texts( labels, (bboxes[:, :2] + bboxes[:, 2:]) / 2, vertical_alignments='center', horizontal_alignments='center', colors='k', font_sizes=font_size) - return visualizer.get_image() + return self.get_image() - @staticmethod - def _draw_polygons(visualizer: Visualizer, - image: np.ndarray, - polygons: Sequence[np.ndarray], - colors: Union[str, Sequence[str]] = 'g', - filling: bool = False, - line_width: Union[int, float] = 0.5, - alpha: float = 0.5) -> np.ndarray: - if colors is not None and isinstance(colors, Sequence): + def get_polygons_image(self, + image: np.ndarray, + polygons: Sequence[np.ndarray], + colors: Union[str, Sequence[str]] = 'g', + filling: bool = False, + line_width: Union[int, float] = 0.5, + alpha: float = 0.5) -> np.ndarray: + """Draw polygons on image. + + Args: + image (np.ndarray): The origin image to draw. The format + should be RGB. + polygons (Sequence[np.ndarray]): The polygons to draw. The shape + should be (N, 2). + colors (Union[str, Sequence[str]]): The colors of polygons. + ``colors`` can have the same length with polygons or just + single value. If ``colors`` is single value, all the polygons + will have the same colors. Refer to `matplotlib.colors` for + full list of formats that are accepted. Defaults to 'g'. + filling (bool): Whether to fill the polygons. Defaults to False. + line_width (Union[int, float]): The line width of polygons. + Defaults to 0.5. + alpha (float): The alpha of polygons. Defaults to 0.5. + + Returns: + np.ndarray: The image with polygons drawn. + """ + if colors is not None and isinstance(colors, (list, tuple)): size = math.ceil(len(polygons) / len(colors)) colors = (colors * size)[:len(polygons)] - visualizer.set_image(image) + self.set_image(image) if filling: - visualizer.draw_polygons( + self.draw_polygons( polygons, face_colors=colors, edge_colors=colors, line_widths=line_width, alpha=alpha) else: - visualizer.draw_polygons( + self.draw_polygons( polygons, edge_colors=colors, line_widths=line_width, alpha=alpha) - return visualizer.get_image() + return self.get_image() - @staticmethod - def _draw_bboxes(visualizer: Visualizer, - image: np.ndarray, - bboxes: Union[np.ndarray, torch.Tensor], - colors: Union[str, Sequence[str]] = 'g', - filling: bool = False, - line_width: Union[int, float] = 0.5, - alpha: float = 0.5) -> np.ndarray: - if colors is not None and isinstance(colors, Sequence): + def get_bboxes_image(self: Visualizer, + image: np.ndarray, + bboxes: Union[np.ndarray, torch.Tensor], + colors: Union[str, Sequence[str]] = 'g', + filling: bool = False, + line_width: Union[int, float] = 0.5, + alpha: float = 0.5) -> np.ndarray: + """Draw bboxes on image. + + Args: + image (np.ndarray): The origin image to draw. The format + should be RGB. + bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw. + colors (Union[str, Sequence[str]]): The colors of bboxes. + ``colors`` can have the same length with bboxes or just single + value. If ``colors`` is single value, all the bboxes will have + the same colors. 
Refer to `matplotlib.colors` for full list of + formats that are accepted. Defaults to 'g'. + filling (bool): Whether to fill the bboxes. Defaults to False. + line_width (Union[int, float]): The line width of bboxes. + Defaults to 0.5. + alpha (float): The alpha of bboxes. Defaults to 0.5. + + Returns: + np.ndarray: The image with bboxes drawn. + """ + if colors is not None and isinstance(colors, (list, tuple)): size = math.ceil(len(bboxes) / len(colors)) colors = (colors * size)[:len(bboxes)] - visualizer.set_image(image) + self.set_image(image) if filling: - visualizer.draw_bboxes( + self.draw_bboxes( bboxes, face_colors=colors, edge_colors=colors, line_widths=line_width, alpha=alpha) else: - visualizer.draw_bboxes( + self.draw_bboxes( bboxes, edge_colors=colors, line_widths=line_width, alpha=alpha) - return visualizer.get_image() + return self.get_image() def _draw_instances(self) -> np.ndarray: raise NotImplementedError + + def _cat_image(self, imgs: Sequence[np.ndarray], axis: int) -> np.ndarray: + """Concatenate images. + + Args: + imgs (Sequence[np.ndarray]): The images to concatenate. + axis (int): The axis to concatenate. + + Returns: + np.ndarray: The concatenated image. + """ + cat_image = list() + for img in imgs: + if img is not None: + cat_image.append(img) + if len(cat_image): + return np.concatenate(cat_image, axis=axis) + else: + return None diff --git a/mmocr/visualization/kie_visualizer.py b/mmocr/visualization/kie_visualizer.py index 25c2620c..b29cceb9 100644 --- a/mmocr/visualization/kie_visualizer.py +++ b/mmocr/visualization/kie_visualizer.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import math import warnings from typing import Dict, List, Optional, Sequence, Union @@ -15,31 +14,11 @@ from mmengine.visualization.utils import (check_type, check_type_and_length, from mmocr.registry import VISUALIZERS from mmocr.structures import KIEDataSample - -PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), - (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), - (250, 170, 30), (100, 170, 30), (220, 220, 0), (175, 116, 175), - (250, 0, 30), (165, 42, 42), (255, 77, 255), (0, 226, 252), - (182, 182, 255), (0, 82, 0), (120, 166, 157), (110, 76, 0), - (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240), - (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176), - (255, 99, 164), (92, 0, 73), (133, 129, 255), (78, 180, 255), - (0, 228, 0), (174, 255, 243), (45, 89, 255), (134, 134, 103), - (145, 148, 174), (255, 208, 186), (197, 226, 255), (171, 134, 1), - (109, 63, 54), (207, 138, 255), (151, 0, 95), (9, 80, 61), - (84, 105, 51), (74, 65, 105), (166, 196, 102), (208, 195, 210), - (255, 109, 65), (0, 143, 149), (179, 0, 194), (209, 99, 106), - (5, 121, 0), (227, 255, 205), (147, 186, 208), (153, 69, 1), - (3, 95, 161), (163, 255, 0), (119, 0, 170), (0, 182, 199), - (0, 165, 120), (183, 130, 88), (95, 32, 0), (130, 114, 135), - (110, 129, 133), (166, 74, 118), (219, 142, 185), (79, 210, 114), - (178, 90, 62), (65, 70, 15), (127, 167, 115), (59, 105, 106), - (142, 108, 45), (196, 172, 0), (95, 54, 80), (128, 76, 255), - (201, 57, 1), (246, 0, 122), (191, 162, 208)] +from .base_visualizer import BaseLocalVisualizer @VISUALIZERS.register_module() -class KIELocalVisualizer(Visualizer): +class KIELocalVisualizer(BaseLocalVisualizer): """The MMOCR Text Detection Local Visualizer. 
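
    This visualizer stitches several panels side by side: the input
    image with instances drawn on it, a panel with the annotated or
    recognized texts, a panel with the bbox classes and, for the
    openset task, a panel with the predicted key-value edges.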
Args: @@ -65,102 +44,6 @@ class KIELocalVisualizer(Visualizer): super().__init__(name=name, **kwargs) self.is_openset = is_openset - @staticmethod - def _draw_labels(visualizer: Visualizer, - image: np.ndarray, - labels: Union[np.ndarray, torch.Tensor], - bboxes: Union[np.ndarray, torch.Tensor], - colors: Union[str, Sequence[str]] = 'k', - font_size: Union[int, float] = 10, - auto_font_size: bool = False) -> np.ndarray: - """Draw labels on image. - - Args: - image (np.ndarray): The origin image to draw. The format - should be RGB. - labels (Union[np.ndarray, torch.Tensor]): The labels to draw. - bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw. - colors (Union[str, Sequence[str]]): The colors of labels. - ``colors`` can have the same length with labels or just single - value. If ``colors`` is single value, all the labels will have - the same colors. Refer to `matplotlib.colors` for full list of - formats that are accepted. Defaults to 'k'. - font_size (Union[int, float]): The font size of labels. Defaults - to 10. - auto_font_size (bool): Whether to automatically adjust font size. - Defaults to False. - """ - if colors is not None and isinstance(colors, Sequence): - size = math.ceil(len(labels) / len(colors)) - colors = (colors * size)[:len(labels)] - if auto_font_size: - assert font_size is not None and isinstance( - font_size, (int, float)) - font_size = (bboxes[:, 2:] - bboxes[:, :2]).min(-1) * font_size - font_size = font_size.tolist() - visualizer.set_image(image) - visualizer.draw_texts( - labels, (bboxes[:, :2] + bboxes[:, 2:]) / 2, - vertical_alignments='center', - horizontal_alignments='center', - colors='k', - font_sizes=font_size) - return visualizer.get_image() - - @staticmethod - def _draw_polygons(visualizer: Visualizer, - image: np.ndarray, - polygons: Sequence[np.ndarray], - colors: Union[str, Sequence[str]] = 'g', - filling: bool = False, - line_width: Union[int, float] = 0.5, - alpha: float = 0.5) -> np.ndarray: - if colors is not None and isinstance(colors, Sequence): - size = math.ceil(len(polygons) / len(colors)) - colors = (colors * size)[:len(polygons)] - visualizer.set_image(image) - if filling: - visualizer.draw_polygons( - polygons, - face_colors=colors, - edge_colors=colors, - line_widths=line_width, - alpha=alpha) - else: - visualizer.draw_polygons( - polygons, - edge_colors=colors, - line_widths=line_width, - alpha=alpha) - return visualizer.get_image() - - @staticmethod - def _draw_bboxes(visualizer: Visualizer, - image: np.ndarray, - bboxes: Union[np.ndarray, torch.Tensor], - colors: Union[str, Sequence[str]] = 'g', - filling: bool = False, - line_width: Union[int, float] = 0.5, - alpha: float = 0.5) -> np.ndarray: - if colors is not None and isinstance(colors, Sequence): - size = math.ceil(len(bboxes) / len(colors)) - colors = (colors * size)[:len(bboxes)] - visualizer.set_image(image) - if filling: - visualizer.draw_bboxes( - bboxes, - face_colors=colors, - edge_colors=colors, - line_widths=line_width, - alpha=alpha) - else: - visualizer.draw_bboxes( - bboxes, - edge_colors=colors, - line_widths=line_width, - alpha=alpha) - return visualizer.get_image() - def _draw_edge_label(self, image: np.ndarray, edge_labels: Union[np.ndarray, torch.Tensor], @@ -182,6 +65,9 @@ class KIELocalVisualizer(Visualizer): arrow_colors (str, optional): The colors of arrows. Refer to `matplotlib.colors` for full list of formats that are accepted. Defaults to 'g'. + + Returns: + np.ndarray: The image with edge labels drawn. 
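+
+        Note:
+            Key-value pairs are read from the nonzero entries of
+            ``edge_labels``: for every pair ``(i, j)`` with
+            ``edge_labels[i, j] > 0``, ``bboxes[i]`` is treated as the
+            key and ``bboxes[j]`` as the value, and an arrow is drawn
+            between them. A toy matrix illustrates the pairing:
+
+            >>> import numpy as np
+            >>> edge_labels = np.array([[0, 1], [0, 0]])
+            >>> np.where(edge_labels > 0)
+            (array([0]), array([1]))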
""" pairs = np.where(edge_labels > 0) key_bboxes = bboxes[pairs[0]] @@ -253,49 +139,45 @@ class KIELocalVisualizer(Visualizer): class_names (dict): The class names for bbox labels. is_openset (bool): Whether the dataset is openset. Defaults to False. + arrow_colors (str, optional): The colors of arrows. Refer to + `matplotlib.colors` for full list of formats that are accepted. + Defaults to 'g'. + + Returns: + np.ndarray: The image with instances drawn. """ img_shape = image.shape[:2] empty_shape = (img_shape[0], img_shape[1], 3) - if polygons: - polygons = [polygon.reshape(-1, 2) for polygon in polygons] - if polygons: - image = self._draw_polygons( - self, image, polygons, filling=True, colors=PALETTE) - else: - image = self._draw_bboxes( - self, image, bboxes, filling=True, colors=PALETTE) - text_image = np.full(empty_shape, 255, dtype=np.uint8) - text_image = self._draw_labels(self, text_image, texts, bboxes) - if polygons: - text_image = self._draw_polygons( - self, text_image, polygons, colors=PALETTE) - else: - text_image = self._draw_bboxes( - self, text_image, bboxes, colors=PALETTE) + text_image = self.get_labels_image(text_image, texts, bboxes) classes_image = np.full(empty_shape, 255, dtype=np.uint8) bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels] - classes_image = self._draw_labels(self, classes_image, bbox_classes, - bboxes) + classes_image = self.get_labels_image(classes_image, bbox_classes, + bboxes) if polygons: - classes_image = self._draw_polygons( - self, classes_image, polygons, colors=PALETTE) + polygons = [polygon.reshape(-1, 2) for polygon in polygons] + image = self.get_polygons_image( + image, polygons, filling=True, colors=self.PALETTE) + text_image = self.get_polygons_image( + text_image, polygons, colors=self.PALETTE) + classes_image = self.get_polygons_image( + classes_image, polygons, colors=self.PALETTE) else: - classes_image = self._draw_bboxes( - self, classes_image, bboxes, colors=PALETTE) - - edge_image = None + image = self.get_bboxes_image( + image, bboxes, filling=True, colors=self.PALETTE) + text_image = self.get_bboxes_image( + text_image, bboxes, colors=self.PALETTE) + classes_image = self.get_bboxes_image( + classes_image, bboxes, colors=self.PALETTE) + cat_image = [image, text_image, classes_image] if is_openset: edge_image = np.full(empty_shape, 255, dtype=np.uint8) edge_image = self._draw_edge_label(edge_image, edge_labels, bboxes, texts, arrow_colors) - cat_image = [] - for i in [image, text_image, classes_image, edge_image]: - if i is not None: - cat_image.append(i) - return np.concatenate(cat_image, axis=1) + cat_image.append(edge_image) + return self._cat_image(cat_image, axis=1) def add_datasample(self, name: str, @@ -336,8 +218,7 @@ class KIELocalVisualizer(Visualizer): out_file (str): Path to output file. Defaults to None. step (int): Global step value to record. Defaults to 0. 
""" - gt_img_data = None - pred_img_data = None + cat_images = list() if draw_gt: gt_bboxes = data_sample.gt_instances.bboxes @@ -350,6 +231,7 @@ class KIELocalVisualizer(Visualizer): gt_texts, self.dataset_meta['category'], self.is_openset, 'g') + cat_images.append(gt_img_data) if draw_pred: gt_bboxes = data_sample.gt_instances.bboxes pred_labels = data_sample.pred_instances.labels @@ -362,22 +244,19 @@ class KIELocalVisualizer(Visualizer): gt_texts, self.dataset_meta['category'], self.is_openset, 'r') - if gt_img_data is not None and pred_img_data is not None: - drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=0) - elif gt_img_data is not None: - drawn_img = gt_img_data - elif pred_img_data is not None: - drawn_img = pred_img_data - else: - drawn_img = image + cat_images.append(pred_img_data) + + cat_images = self._cat_image(cat_images, axis=0) + if cat_images is None: + cat_images = image if show: - self.show(drawn_img, win_name=name, wait_time=wait_time) + self.show(cat_images, win_name=name, wait_time=wait_time) else: - self.add_image(name, drawn_img, step) + self.add_image(name, cat_images, step) if out_file is not None: - mmcv.imwrite(drawn_img[..., ::-1], out_file) + mmcv.imwrite(cat_images[..., ::-1], out_file) def draw_arrows(self, x_data: Union[np.ndarray, torch.Tensor], diff --git a/mmocr/visualization/textdet_visualizer.py b/mmocr/visualization/textdet_visualizer.py index 15209670..5f52074a 100644 --- a/mmocr/visualization/textdet_visualizer.py +++ b/mmocr/visualization/textdet_visualizer.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import mmcv import numpy as np -from mmengine.visualization import Visualizer +import torch from mmocr.registry import VISUALIZERS from mmocr.structures import TextDetDataSample +from .base_visualizer import BaseLocalVisualizer @VISUALIZERS.register_module() -class TextDetLocalVisualizer(Visualizer): +class TextDetLocalVisualizer(BaseLocalVisualizer): """The MMOCR Text Detection Local Visualizer. Args: @@ -62,6 +63,42 @@ class TextDetLocalVisualizer(Visualizer): self.line_width = line_width self.alpha = alpha + def _draw_instances( + self, + image: np.ndarray, + bboxes: Union[np.ndarray, torch.Tensor], + polygons: Sequence[np.ndarray], + color: Union[str, Tuple, List[str], List[Tuple]] = 'g', + ) -> np.ndarray: + """Draw bboxes and polygons on image. + + Args: + image (np.ndarray): The origin image to draw. + bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw. + polygons (Sequence[np.ndarray]): The polygons to draw. + color (Union[str, tuple, list[str], list[tuple]]): The + colors of polygons and bboxes. ``colors`` can have the same + length with lines or just single value. If ``colors`` is + single value, all the lines will have the same colors. Refer + to `matplotlib.colors` for full list of formats that are + accepted. Defaults to 'g'. + + Returns: + np.ndarray: The image with bboxes and polygons drawn. 
+ """ + if polygons is not None and self.with_poly: + polygons = [polygon.reshape(-1, 2) for polygon in polygons] + image = self.get_polygons_image( + image, polygons, filling=True, colors=color, alpha=self.alpha) + if bboxes is not None and self.with_bbox: + image = self.get_bboxes_image( + image, + bboxes, + colors=color, + line_width=self.line_width, + alpha=self.alpha) + return image + def add_datasample(self, name: str, image: np.ndarray, @@ -101,79 +138,32 @@ class TextDetLocalVisualizer(Visualizer): and masks. Defaults to 0.3. step (int): Global step value to record. Defaults to 0. """ - gt_img_data = None - pred_img_data = None - - if (draw_gt and data_sample is not None - and 'gt_instances' in data_sample): - gt_instances = data_sample.gt_instances - - self.set_image(image) - - if self.with_poly and 'polygons' in gt_instances: - gt_polygons = gt_instances.polygons - gt_polygons = [ - gt_polygon.reshape(-1, 2) for gt_polygon in gt_polygons - ] - self.draw_polygons( - gt_polygons, - alpha=self.alpha, - edge_colors=self.gt_color, - line_widths=self.line_width) - - if self.with_bbox and 'bboxes' in gt_instances: - gt_bboxes = gt_instances.bboxes - self.draw_bboxes( - gt_bboxes, - alpha=self.alpha, - edge_colors=self.gt_color, - line_widths=self.line_width) - - gt_img_data = self.get_image() - - if draw_pred and data_sample is not None \ - and 'pred_instances' in data_sample: - pred_instances = data_sample.pred_instances - pred_instances = pred_instances[ - pred_instances.scores > pred_score_thr].cpu() - - self.set_image(image) - - if self.with_poly and 'polygons' in pred_instances: - pred_polygons = pred_instances.polygons - pred_polygons = [ - pred_polygon.reshape(-1, 2) - for pred_polygon in pred_polygons - ] - self.draw_polygons( - pred_polygons, - alpha=self.alpha, - edge_colors=self.pred_color, - line_widths=self.line_width) - - if self.with_bbox and 'bboxes' in pred_instances: - pred_bboxes = pred_instances.bboxes - self.draw_bboxes( - pred_bboxes, - alpha=self.alpha, - edge_colors=self.pred_color, - line_widths=self.line_width) - - pred_img_data = self.get_image() - - if gt_img_data is not None and pred_img_data is not None: - drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) - elif gt_img_data is not None: - drawn_img = gt_img_data - elif pred_img_data is not None: - drawn_img = pred_img_data - else: - drawn_img = image - + cat_images = [] + if data_sample is not None: + if draw_gt and 'gt_instances' in data_sample: + gt_instances = data_sample.gt_instances + gt_polygons = gt_instances.get('polygons', None) + gt_bboxes = gt_instances.get('bboxes', None) + gt_img_data = self._draw_instances(image.copy(), gt_bboxes, + gt_polygons, self.gt_color) + cat_images.append(gt_img_data) + if draw_pred and 'pred_instances' in data_sample: + pred_instances = data_sample.pred_instances + pred_instances = pred_instances[ + pred_instances.scores > pred_score_thr].cpu() + pred_polygons = pred_instances.get('polygons', None) + pred_bboxes = pred_instances.get('bboxes', None) + pred_img_data = self._draw_instances(image.copy(), pred_bboxes, + pred_polygons, + self.pred_color) + cat_images.append(pred_img_data) + cat_images = self._cat_image(cat_images, axis=1) + if cat_images is None: + cat_images = image if show: - self.show(drawn_img, win_name=name, wait_time=wait_time) + self.show(cat_images, win_name=name, wait_time=wait_time) else: - self.add_image(name, drawn_img, step) + self.add_image(name, cat_images, step) if out_file is not None: - mmcv.imwrite(drawn_img[..., ::-1], 
out_file) + mmcv.imwrite(cat_images[..., ::-1], out_file) diff --git a/mmocr/visualization/textrecog_visualizer.py b/mmocr/visualization/textrecog_visualizer.py index 5db03830..623bf764 100644 --- a/mmocr/visualization/textrecog_visualizer.py +++ b/mmocr/visualization/textrecog_visualizer.py @@ -4,14 +4,14 @@ from typing import Dict, Optional, Tuple, Union import cv2 import mmcv import numpy as np -from mmengine.visualization import Visualizer from mmocr.registry import VISUALIZERS from mmocr.structures import TextRecogDataSample +from .base_visualizer import BaseLocalVisualizer @VISUALIZERS.register_module() -class TextRecogLocalVisualizer(Visualizer): +class TextRecogLocalVisualizer(BaseLocalVisualizer): """MMOCR Text Detection Local Visualizer. Args: @@ -46,6 +46,30 @@ class TextRecogLocalVisualizer(Visualizer): self.gt_color = gt_color self.pred_color = pred_color + def _draw_instances(self, image: np.ndarray, text: str) -> np.ndarray: + """Draw text on image. + + Args: + image (np.ndarray): The image to draw. + text (str): The text to draw. + + Returns: + np.ndarray: The image with text drawn. + """ + height, width = image.shape[:2] + empty_img = np.full_like(image, 255) + self.set_image(empty_img) + font_size = 0.5 * width / (len(text) + 1) + self.draw_texts( + text, + np.array([width / 2, height / 2]), + colors=self.gt_color, + font_sizes=font_size, + vertical_alignments='center', + horizontal_alignments='center') + text_image = self.get_image() + return text_image + def add_datasample(self, name: str, image: np.ndarray, @@ -85,59 +109,28 @@ class TextRecogLocalVisualizer(Visualizer): pred_score_thr (float): Threshold of prediction score. It's not used in this function. Defaults to None. """ - gt_img_data = None - pred_img_data = None height, width = image.shape[:2] resize_height = 64 resize_width = int(1.0 * width / height * resize_height) image = cv2.resize(image, (resize_width, resize_height)) + if image.ndim == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + cat_images = [image] if draw_gt and data_sample is not None and 'gt_text' in data_sample: gt_text = data_sample.gt_text.item - empty_img = np.full_like(image, 255) - self.set_image(empty_img) - font_size = 0.5 * resize_width / (len(gt_text) + 1) - self.draw_texts( - gt_text, - np.array([resize_width / 2, resize_height / 2]), - colors=self.gt_color, - font_sizes=font_size, - vertical_alignments='center', - horizontal_alignments='center') - gt_text_image = self.get_image() - gt_img_data = np.concatenate((image, gt_text_image), axis=0) - + cat_images.append(self._draw_instances(image, gt_text)) if (draw_pred and data_sample is not None and 'pred_text' in data_sample): pred_text = data_sample.pred_text.item - empty_img = np.full_like(image, 255) - self.set_image(empty_img) - font_size = 0.5 * resize_width / (len(pred_text) + 1) - self.draw_texts( - pred_text, - np.array([resize_width / 2, resize_height / 2]), - colors=self.pred_color, - font_sizes=font_size, - vertical_alignments='center', - horizontal_alignments='center') - pred_text_image = self.get_image() - pred_img_data = np.concatenate((image, pred_text_image), axis=0) - - if gt_img_data is not None and pred_img_data is not None: - drawn_img = np.concatenate((gt_img_data, pred_text_image), axis=0) - elif gt_img_data is not None: - drawn_img = gt_img_data - elif pred_img_data is not None: - drawn_img = pred_img_data - else: - drawn_img = image + cat_images.append(self._draw_instances(image, pred_text)) + cat_images = self._cat_image(cat_images, axis=0) if show: - 
self.show(drawn_img, win_name=name, wait_time=wait_time) + self.show(cat_images, win_name=name, wait_time=wait_time) else: - self.add_image(name, drawn_img, step) + self.add_image(name, cat_images, step) if out_file is not None: - mmcv.imwrite(drawn_img[..., ::-1], out_file) + mmcv.imwrite(cat_images[..., ::-1], out_file) diff --git a/mmocr/visualization/textspotting_visualizer.py b/mmocr/visualization/textspotting_visualizer.py index 1571d88d..19a5e4ad 100644 --- a/mmocr/visualization/textspotting_visualizer.py +++ b/mmocr/visualization/textspotting_visualizer.py @@ -37,27 +37,26 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer): should be the same as the number of bboxes. class_names (dict): The class names for bbox labels. is_openset (bool): Whether the dataset is openset. Default: False. + + Returns: + np.ndarray: The image with instances drawn. """ img_shape = image.shape[:2] empty_shape = (img_shape[0], img_shape[1], 3) - + text_image = np.full(empty_shape, 255, dtype=np.uint8) + text_image = self.get_labels_image( + text_image, labels=texts, bboxes=bboxes) if polygons: polygons = [polygon.reshape(-1, 2) for polygon in polygons] - if polygons: - image = self._draw_polygons( - self, image, polygons, filling=True, colors=self.PALETTE) + image = self.get_polygons_image( + image, polygons, filling=True, colors=self.PALETTE) + text_image = self.get_polygons_image( + text_image, polygons, colors=self.PALETTE) else: - image = self._draw_bboxes( - self, image, bboxes, filling=True, colors=self.PALETTE) - - text_image = np.full(empty_shape, 255, dtype=np.uint8) - text_image = self._draw_labels(self, text_image, texts, bboxes) - if polygons: - text_image = self._draw_polygons( - self, text_image, polygons, colors=self.PALETTE) - else: - text_image = self._draw_bboxes( - self, text_image, bboxes, colors=self.PALETTE) + image = self.get_bboxes_image( + image, bboxes, filling=True, colors=self.PALETTE) + text_image = self.get_bboxes_image( + text_image, bboxes, colors=self.PALETTE) return np.concatenate([image, text_image], axis=1) def add_datasample(self, @@ -68,43 +67,69 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer): draw_pred: bool = True, show: bool = False, wait_time: int = 0, - pred_score_thr: float = None, + pred_score_thr: float = 0.5, out_file: Optional[str] = None, step: int = 0) -> None: - gt_img_data = None - pred_img_data = None + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. This is usually used when the display + is not available. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + data_sample (:obj:`TextSpottingDataSample`, optional): + TextDetDataSample which contains gt and prediction. Defaults + to None. + draw_gt (bool): Whether to draw GT TextDetDataSample. + Defaults to True. + draw_pred (bool): Whether to draw Predicted TextDetDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + pred_score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. 
+ step (int): Global step value to record. Defaults to 0. + """ + cat_images = [] if draw_gt: - gt_bboxes = data_sample.gt_instances.bboxes + gt_bboxes = data_sample.gt_instances.get('bboxes', None) gt_texts = data_sample.gt_instances.texts - gt_polygons = data_sample.gt_instances.polygons + gt_polygons = data_sample.gt_instances.get('polygons', None) gt_img_data = self._draw_instances(image, gt_bboxes, gt_polygons, gt_texts) + cat_images.append(gt_img_data) + if draw_pred: pred_instances = data_sample.pred_instances pred_instances = pred_instances[ pred_instances.scores > pred_score_thr].cpu().numpy() pred_bboxes = pred_instances.get('bboxes', None) pred_texts = pred_instances.texts - pred_polygons = pred_instances.polygons + pred_polygons = pred_instances.get('polygons', None) if pred_bboxes is None: pred_bboxes = [poly2bbox(poly) for poly in pred_polygons] pred_bboxes = np.array(pred_bboxes) pred_img_data = self._draw_instances(image, pred_bboxes, pred_polygons, pred_texts) - if gt_img_data is not None and pred_img_data is not None: - drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=0) - elif gt_img_data is not None: - drawn_img = gt_img_data - elif pred_img_data is not None: - drawn_img = pred_img_data - else: - drawn_img = image + cat_images.append(pred_img_data) + + cat_images = self._cat_image(cat_images, axis=0) + if cat_images is None: + cat_images = image if show: - self.show(drawn_img, win_name=name, wait_time=wait_time) + self.show(cat_images, win_name=name, wait_time=wait_time) else: - self.add_image(name, drawn_img, step) + self.add_image(name, cat_images, step) if out_file is not None: - mmcv.imwrite(drawn_img[..., ::-1], out_file) + mmcv.imwrite(cat_images[..., ::-1], out_file) diff --git a/mmocr/visualization/visualize.py b/mmocr/visualization/visualize.py deleted file mode 100644 index a8af6f34..00000000 --- a/mmocr/visualization/visualize.py +++ /dev/null @@ -1,890 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math -import os -import shutil -import urllib -import warnings - -import cv2 -import mmcv -import mmengine -import numpy as np -import torch -from matplotlib import pyplot as plt -from PIL import Image, ImageDraw, ImageFont - -import mmocr.utils as utils - - -# TODO remove after KieVisualizer and TextSpotterVisualizer -def overlay_mask_img(img, mask): - """Draw mask boundaries on image for visualization. - - Args: - img (ndarray): The input image. - mask (ndarray): The instance mask. - - Returns: - img (ndarray): The output image with instance boundaries on it. - """ - assert isinstance(img, np.ndarray) - assert isinstance(mask, np.ndarray) - - contours, _ = cv2.findContours( - mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - - cv2.drawContours(img, contours, -1, (0, 255, 0), 1) - - return img - - -def show_feature(features, names, to_uint8, out_file=None): - """Visualize a list of feature maps. - - Args: - features (list(ndarray)): The feature map list. - names (list(str)): The visualized title list. - to_uint8 (list(1|0)): The list indicating whether to convent - feature maps to uint8. - out_file (str): The output file name. If set to None, - the output image will be shown without saving. 
- """ - assert utils.is_type_list(features, np.ndarray) - assert utils.is_type_list(names, str) - assert utils.is_type_list(to_uint8, int) - assert utils.is_none_or_type(out_file, str) - assert utils.equal_len(features, names, to_uint8) - - num = len(features) - row = col = math.ceil(math.sqrt(num)) - - for i, (f, n) in enumerate(zip(features, names)): - plt.subplot(row, col, i + 1) - plt.title(n) - if to_uint8[i]: - f = f.astype(np.uint8) - plt.imshow(f) - if out_file is None: - plt.show() - else: - plt.savefig(out_file) - - -def show_img_boundary(img, boundary): - """Show image and instance boundaires. - - Args: - img (ndarray): The input image. - boundary (list[float or int]): The input boundary. - """ - assert isinstance(img, np.ndarray) - assert utils.is_type_list(boundary, (int, float)) - - cv2.polylines( - img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)], - True, - color=(0, 255, 0), - thickness=1) - plt.imshow(img) - plt.show() - - -def show_pred_gt(preds, - gts, - show=False, - win_name='', - wait_time=0, - out_file=None): - """Show detection and ground truth for one image. - - Args: - preds (list[list[float]]): The detection boundary list. - gts (list[list[float]]): The ground truth boundary list. - show (bool): Whether to show the image. - win_name (str): The window name. - wait_time (int): The value of waitKey param. - out_file (str): The filename of the output. - """ - assert utils.is_2dlist(preds) - assert utils.is_2dlist(gts) - assert isinstance(show, bool) - assert isinstance(win_name, str) - assert isinstance(wait_time, int) - assert utils.is_none_or_type(out_file, str) - - p_xy = [p for boundary in preds for p in boundary] - gt_xy = [g for gt in gts for g in gt] - - max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0) - - width = int(max_xy[0]) + 100 - height = int(max_xy[1]) + 100 - - img = np.ones((height, width, 3), np.int8) * 255 - pred_color = mmcv.color_val('red') - gt_color = mmcv.color_val('blue') - thickness = 1 - - for boundary in preds: - cv2.polylines( - img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)], - True, - color=pred_color, - thickness=thickness) - for gt in gts: - cv2.polylines( - img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)], - True, - color=gt_color, - thickness=thickness) - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) - - return img - - -def imshow_pred_boundary(img, - boundaries_with_scores, - labels, - score_thr=0, - boundary_color='blue', - text_color='blue', - thickness=1, - font_scale=0.5, - show=True, - win_name='', - wait_time=0, - out_file=None, - show_score=False): - """Draw boundaries and class labels (with scores) on an image. - - Args: - img (str or ndarray): The image to be displayed. - boundaries_with_scores (list[list[float]]): Boundaries with scores. - labels (list[int]): Labels of boundaries. - score_thr (float): Minimum score of boundaries to be shown. - boundary_color (str or tuple or :obj:`Color`): Color of boundaries. - text_color (str or tuple or :obj:`Color`): Color of texts. - thickness (int): Thickness of lines. - font_scale (float): Font scales of texts. - show (bool): Whether to show the image. - win_name (str): The window name. - wait_time (int): Value of waitKey param. - out_file (str or None): The filename of the output. - show_score (bool): Whether to show text instance score. 
- """ - assert isinstance(img, (str, np.ndarray)) - assert utils.is_2dlist(boundaries_with_scores) - assert utils.is_type_list(labels, int) - assert utils.equal_len(boundaries_with_scores, labels) - if len(boundaries_with_scores) == 0: - warnings.warn('0 text found in ' + out_file) - return None - - utils.valid_boundary(boundaries_with_scores[0]) - img = mmcv.imread(img) - - scores = np.array([b[-1] for b in boundaries_with_scores]) - inds = scores > score_thr - boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]] - scores = [scores[i] for i in np.where(inds)[0]] - labels = [labels[i] for i in np.where(inds)[0]] - - boundary_color = mmcv.color_val(boundary_color) - text_color = mmcv.color_val(text_color) - font_scale = 0.5 - - for boundary, score in zip(boundaries, scores): - boundary_int = np.array(boundary).astype(np.int32) - - cv2.polylines( - img, [boundary_int.reshape(-1, 1, 2)], - True, - color=boundary_color, - thickness=thickness) - - if show_score: - label_text = f'{score:.02f}' - cv2.putText(img, label_text, - (boundary_int[0], boundary_int[1] - 2), - cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) - - return img - - -def imshow_text_char_boundary(img, - text_quads, - boundaries, - char_quads, - chars, - show=False, - thickness=1, - font_scale=0.5, - win_name='', - wait_time=-1, - out_file=None): - """Draw text boxes and char boxes on img. - - Args: - img (str or ndarray): The img to be displayed. - text_quads (list[list[int|float]]): The text boxes. - boundaries (list[list[int|float]]): The boundary list. - char_quads (list[list[list[int|float]]]): A 2d list of char boxes. - char_quads[i] is for the ith text, and char_quads[i][j] is the jth - char of the ith text. - chars (list[list[char]]). The string for each text box. - thickness (int): Thickness of lines. - font_scale (float): Font scales of texts. - show (bool): Whether to show the image. - win_name (str): The window name. - wait_time (int): Value of waitKey param. - out_file (str or None): The filename of the output. 
- """ - assert isinstance(img, (np.ndarray, str)) - assert utils.is_2dlist(text_quads) - assert utils.is_2dlist(boundaries) - assert utils.is_3dlist(char_quads) - assert utils.is_2dlist(chars) - assert utils.equal_len(text_quads, char_quads, boundaries) - - img = mmcv.imread(img) - char_color = [mmcv.color_val('blue'), mmcv.color_val('green')] - text_color = mmcv.color_val('red') - text_inx = 0 - for text_box, boundary, char_box, txt in zip(text_quads, boundaries, - char_quads, chars): - text_box = np.array(text_box) - boundary = np.array(boundary) - - text_box = text_box.reshape(-1, 2).astype(np.int32) - cv2.polylines( - img, [text_box.reshape(-1, 1, 2)], - True, - color=text_color, - thickness=thickness) - if boundary.shape[0] > 0: - cv2.polylines( - img, [boundary.reshape(-1, 1, 2)], - True, - color=text_color, - thickness=thickness) - - for b in char_box: - b = np.array(b) - c = char_color[text_inx % 2] - b = b.astype(np.int32) - cv2.polylines( - img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness) - - label_text = ''.join(txt) - cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2), - cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) - text_inx = text_inx + 1 - - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) - - return img - - -def tile_image(images): - """Combined multiple images to one vertically. - - Args: - images (list[np.ndarray]): Images to be combined. - """ - assert isinstance(images, list) - assert len(images) > 0 - - for i, _ in enumerate(images): - if len(images[i].shape) == 2: - images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR) - - widths = [img.shape[1] for img in images] - heights = [img.shape[0] for img in images] - h, w = sum(heights), max(widths) - vis_img = np.zeros((h, w, 3), dtype=np.uint8) - - offset_y = 0 - for image in images: - img_h, img_w = image.shape[:2] - vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image - offset_y += img_h - - return vis_img - - -def imshow_text_label(img, - pred_label, - gt_label, - show=False, - win_name='', - wait_time=-1, - out_file=None): - """Draw predicted texts and ground truth texts on images. - - Args: - img (str or np.ndarray): Image filename or loaded image. - pred_label (str): Predicted texts. - gt_label (str): Ground truth texts. - show (bool): Whether to show the image. - win_name (str): The window name. - wait_time (int): Value of waitKey param. - out_file (str): The filename of the output. 
- """ - assert isinstance(img, (np.ndarray, str)) - assert isinstance(pred_label, str) - assert isinstance(gt_label, str) - assert isinstance(show, bool) - assert isinstance(win_name, str) - assert isinstance(wait_time, int) - - img = mmcv.imread(img) - - src_h, src_w = img.shape[:2] - resize_height = 64 - resize_width = int(1.0 * src_w / src_h * resize_height) - img = cv2.resize(img, (resize_width, resize_height)) - h, w = img.shape[:2] - - if is_contain_chinese(pred_label): - pred_img = draw_texts_by_pil(img, [pred_label], None) - else: - pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255 - cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, - 0.9, (0, 0, 255), 2) - images = [pred_img, img] - - if gt_label != '': - if is_contain_chinese(gt_label): - gt_img = draw_texts_by_pil(img, [gt_label], None) - else: - gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255 - cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, - 0.9, (255, 0, 0), 2) - images.append(gt_img) - - img = tile_image(images) - - if show: - mmcv.imshow(img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(img, out_file) - - return img - - -def imshow_node(img, - result, - boxes, - idx_to_cls={}, - show=False, - win_name='', - wait_time=-1, - out_file=None): - - img = mmcv.imread(img) - h, w = img.shape[:2] - - max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1) - node_pred_label = max_idx.numpy().tolist() - node_pred_score = max_value.numpy().tolist() - - texts, text_boxes = [], [] - for i, box in enumerate(boxes): - new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], - [box[0], box[3]]] - Pts = np.array([new_box], np.int32) - cv2.polylines( - img, [Pts.reshape((-1, 1, 2))], - True, - color=(255, 255, 0), - thickness=1) - x_min = int(min(point[0] for point in new_box)) - y_min = int(min(point[1] for point in new_box)) - - # text - pred_label = str(node_pred_label[i]) - if pred_label in idx_to_cls: - pred_label = idx_to_cls[pred_label] - pred_score = f'{node_pred_score[i]:.2f}' - text = pred_label + '(' + pred_score + ')' - texts.append(text) - - # text box - font_size = int( - min( - abs(new_box[3][1] - new_box[0][1]), - abs(new_box[1][0] - new_box[0][0]))) - char_num = len(text) - text_box = [ - x_min * 2, y_min, x_min * 2 + font_size * char_num, y_min, - x_min * 2 + font_size * char_num, y_min + font_size, x_min * 2, - y_min + font_size - ] - text_boxes.append(text_box) - - pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 - pred_img = draw_texts_by_pil( - pred_img, texts, text_boxes, draw_box=False, on_ori_img=True) - - vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 - vis_img[:, :w] = img - vis_img[:, w:] = pred_img - - if show: - mmcv.imshow(vis_img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(vis_img, out_file) - - return vis_img - - -def gen_color(): - """Generate BGR color schemes.""" - color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249), - (123, 151, 138), (187, 200, 178), (148, 137, 69), - (169, 200, 200), (155, 175, 131), (154, 194, 182), - (178, 190, 137), (140, 211, 222), (83, 156, 222)] - return color_list - - -def draw_polygons(img, polys): - """Draw polygons on image. - - Args: - img (np.ndarray): The original image. - polys (list[list[float]]): Detected polygons. - Return: - out_img (np.ndarray): Visualized image. 
- """ - dst_img = img.copy() - color_list = gen_color() - out_img = dst_img - for idx, poly in enumerate(polys): - poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32) - cv2.drawContours( - img, - np.array([poly]), - -1, - color_list[idx % len(color_list)], - thickness=cv2.FILLED) - out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0) - return out_img - - -def get_optimal_font_scale(text, width): - """Get optimal font scale for cv2.putText. - - Args: - text (str): Text in one box. - width (int): The box width. - """ - for scale in reversed(range(0, 60, 1)): - textSize = cv2.getTextSize( - text, - fontFace=cv2.FONT_HERSHEY_SIMPLEX, - fontScale=scale / 10, - thickness=1) - new_width = textSize[0][0] - if new_width <= width: - return scale / 10 - return 1 - - -def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False): - """Draw boxes and texts on empty img. - - Args: - img (np.ndarray): The original image. - texts (list[str]): Recognized texts. - boxes (list[list[float]]): Detected bounding boxes. - draw_box (bool): Whether draw box or not. If False, draw text only. - on_ori_img (bool): If True, draw box and text on input image, - else, on a new empty image. - Return: - out_img (np.ndarray): Visualized image. - """ - color_list = gen_color() - h, w = img.shape[:2] - if boxes is None: - boxes = [[0, 0, w, 0, w, h, 0, h]] - assert len(texts) == len(boxes) - - if on_ori_img: - out_img = img - else: - out_img = np.ones((h, w, 3), dtype=np.uint8) * 255 - for idx, (box, text) in enumerate(zip(boxes, texts)): - if draw_box: - new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])] - Pts = np.array([new_box], np.int32) - cv2.polylines( - out_img, [Pts.reshape((-1, 1, 2))], - True, - color=color_list[idx % len(color_list)], - thickness=1) - min_x = int(min(box[0::2])) - max_y = int( - np.mean(np.array(box[1::2])) + 0.2 * - (max(box[1::2]) - min(box[1::2]))) - font_scale = get_optimal_font_scale( - text, int(max(box[0::2]) - min(box[0::2]))) - cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX, - font_scale, (0, 0, 0), 1) - - return out_img - - -def draw_texts_by_pil(img, - texts, - boxes=None, - draw_box=True, - on_ori_img=False, - font_size=None, - fill_color=None, - draw_pos=None, - return_text_size=False): - """Draw boxes and texts on empty image, especially for Chinese. - - Args: - img (np.ndarray): The original image. - texts (list[str]): Recognized texts. - boxes (list[list[float]]): Detected bounding boxes. - draw_box (bool): Whether draw box or not. If False, draw text only. - on_ori_img (bool): If True, draw box and text on input image, - else on a new empty image. - font_size (int, optional): Size to create a font object for a font. - fill_color (tuple(int), optional): Fill color for text. - draw_pos (list[tuple(int)], optional): Start point to draw each text. - return_text_size (bool): If True, return the list of text size. - - Returns: - (np.ndarray, list[tuple]) or np.ndarray: Return a tuple - ``(out_img, text_sizes)``, where ``out_img`` is the output image - with texts drawn on it and ``text_sizes`` are the size of drawing - texts. If ``return_text_size`` is False, only the output image will be - returned. 
- """ - - color_list = gen_color() - h, w = img.shape[:2] - if boxes is None: - boxes = [[0, 0, w, 0, w, h, 0, h]] - if draw_pos is None: - draw_pos = [None for _ in texts] - assert len(boxes) == len(texts) == len(draw_pos) - - if fill_color is None: - fill_color = (0, 0, 0) - - if on_ori_img: - out_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - else: - out_img = Image.new('RGB', (w, h), color=(255, 255, 255)) - out_draw = ImageDraw.Draw(out_img) - - text_sizes = [] - for idx, (box, text, ori_point) in enumerate(zip(boxes, texts, draw_pos)): - if len(text) == 0: - continue - min_x, max_x = min(box[0::2]), max(box[0::2]) - min_y, max_y = min(box[1::2]), max(box[1::2]) - color = tuple(list(color_list[idx % len(color_list)])[::-1]) - if draw_box: - out_draw.line(box, fill=color, width=1) - dirname, _ = os.path.split(os.path.abspath(__file__)) - font_path = os.path.join(dirname, 'font.TTF') - if not os.path.exists(font_path): - url = ('https://download.openmmlab.com/mmocr/data/font.TTF') - print(f'Downloading {url} ...') - local_filename, _ = urllib.request.urlretrieve(url) - shutil.move(local_filename, font_path) - tmp_font_size = font_size - if tmp_font_size is None: - box_width = max(max_x - min_x, max_y - min_y) - tmp_font_size = int(0.9 * box_width / len(text)) - fnt = ImageFont.truetype(font_path, tmp_font_size) - if ori_point is None: - ori_point = (min_x + 1, min_y + 1) - out_draw.text(ori_point, text, font=fnt, fill=fill_color) - text_sizes.append(fnt.getsize(text)) - - del out_draw - - out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR) - - if return_text_size: - return out_img, text_sizes - - return out_img - - -def is_contain_chinese(check_str): - """Check whether string contains Chinese or not. - - Args: - check_str (str): String to be checked. - - Return True if contains Chinese, else False. - """ - for ch in check_str: - if '\u4e00' <= ch <= '\u9fff': - return True - return False - - -def det_recog_show_result(img, end2end_res, out_file=None): - """Draw `result`(boxes and texts) on `img`. - - Args: - img (str or np.ndarray): The image to be displayed. - end2end_res (dict): Text detect and recognize results. - out_file (str): Image path where the visualized image should be saved. - Return: - out_img (np.ndarray): Visualized image. - """ - img = mmcv.imread(img) - boxes, texts = [], [] - for res in end2end_res['result']: - boxes.append(res['box']) - texts.append(res['text']) - box_vis_img = draw_polygons(img, boxes) - - if is_contain_chinese(''.join(texts)): - text_vis_img = draw_texts_by_pil(img, texts, boxes) - else: - text_vis_img = draw_texts(img, texts, boxes) - - h, w = img.shape[:2] - out_img = np.ones((h, w * 2, 3), dtype=np.uint8) - out_img[:, :w, :] = box_vis_img - out_img[:, w:, :] = text_vis_img - - if out_file: - mmcv.imwrite(out_img, out_file) - - return out_img - - -def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5): - """Draw text and their relationship on empty images. - - Args: - img (np.ndarray): The original image. - result (dict): The result of model forward_test, including: - - img_metas (list[dict]): List of meta information dictionary. - - nodes (Tensor): Node prediction with size: - number_node * node_classes. - - edges (Tensor): Edge prediction with size: number_edge * 2. - edge_thresh (float): Score threshold for edge classification. - keynode_thresh (float): Score threshold for node - (``key``) classification. - - Returns: - np.ndarray: The image with key, value and relation drawn on it. 
- """ - - h, w = img.shape[:2] - - vis_area_width = w // 3 * 2 - vis_area_height = h - dist_key_to_value = vis_area_width // 2 - dist_pair_to_pair = 30 - - bbox_x1 = dist_pair_to_pair - bbox_y1 = 0 - - new_w = vis_area_width - new_h = vis_area_height - pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255 - - nodes = result['nodes'].detach().cpu() - texts = result['img_metas'][0]['ori_texts'] - num_nodes = result['nodes'].size(0) - edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes) - - # (i, j) will be a valid pair - # either edge_score(node_i->node_j) > edge_thresh - # or edge_score(node_j->node_i) > edge_thresh - pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True) - pairs = (pairs[0].numpy().tolist(), pairs[1].numpy().tolist()) - - # 1. "for n1, n2 in zip(*pairs) if n1 < n2": - # Only (n1, n2) will be included if n1 < n2 but not (n2, n1), to - # avoid duplication. - # 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]": - # nodes[n1, 1] is the score that this node is predicted as key, - # nodes[n1, 2] is the score that this node is predicted as value. - # If nodes[n1, 1] > nodes[n1, 2], n1 will be the index of key, - # so that n2 will be the index of value. - result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1) - for n1, n2 in zip(*pairs) if n1 < n2] - - result_pairs.sort() - result_pairs_score = [ - torch.max(edges[n1, n2], edges[n2, n1]) for n1, n2 in result_pairs - ] - - key_current_idx = -1 - pos_current = (-1, -1) - newline_flag = False - - key_font_size = 15 - value_font_size = 15 - key_font_color = (0, 0, 0) - value_font_color = (0, 0, 255) - arrow_color = (0, 0, 255) - score_color = (0, 255, 0) - for pair, pair_score in zip(result_pairs, result_pairs_score): - key_idx = pair[0] - if nodes[key_idx, 1] < keynode_thresh: - continue - if key_idx != key_current_idx: - # move y-coords down for a new key - bbox_y1 += 10 - # enlarge blank area to show key-value info - if newline_flag: - bbox_x1 += vis_area_width - tmp_img = np.ones( - (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255 - tmp_img[:new_h, :new_w] = pred_edge_img - pred_edge_img = tmp_img - new_w += vis_area_width - newline_flag = False - bbox_y1 = 10 - key_text = texts[key_idx] - key_pos = (bbox_x1, bbox_y1) - value_idx = pair[1] - value_text = texts[value_idx] - value_pos = (bbox_x1 + dist_key_to_value, bbox_y1) - if key_idx != key_current_idx: - # draw text for a new key - key_current_idx = key_idx - pred_edge_img, text_sizes = draw_texts_by_pil( - pred_edge_img, [key_text], - draw_box=False, - on_ori_img=True, - font_size=key_font_size, - fill_color=key_font_color, - draw_pos=[key_pos], - return_text_size=True) - pos_right_bottom = (key_pos[0] + text_sizes[0][0], - key_pos[1] + text_sizes[0][1]) - pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10) - pred_edge_img = cv2.arrowedLine( - pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10), - (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color, - 1) - score_pos_x = int( - (pos_right_bottom[0] + bbox_x1 + dist_key_to_value) / 2.) 
- score_pos_y = bbox_y1 + 10 - int(key_font_size * 0.3) - else: - # draw arrow from key to value - if newline_flag: - tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3), - dtype=np.uint8) * 255 - tmp_img[:new_h, :new_w] = pred_edge_img - pred_edge_img = tmp_img - new_h += dist_pair_to_pair - pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current, - (bbox_x1 + dist_key_to_value - 5, - bbox_y1 + 10), arrow_color, 1) - score_pos_x = int( - (pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.) - score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.) - # draw edge score - cv2.putText(pred_edge_img, f'{pair_score:.2f}', - (score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4, - score_color) - # draw text for value - pred_edge_img = draw_texts_by_pil( - pred_edge_img, [value_text], - draw_box=False, - on_ori_img=True, - font_size=value_font_size, - fill_color=value_font_color, - draw_pos=[value_pos], - return_text_size=False) - bbox_y1 += dist_pair_to_pair - if bbox_y1 + dist_pair_to_pair >= new_h: - newline_flag = True - - return pred_edge_img - - -def imshow_edge(img, - result, - boxes, - show=False, - win_name='', - wait_time=-1, - out_file=None): - """Display the prediction results of the nodes and edges of the KIE model. - - Args: - img (np.ndarray): The original image. - result (dict): The result of model forward_test, including: - - img_metas (list[dict]): List of meta information dictionary. - - nodes (Tensor): Node prediction with size: \ - number_node * node_classes. - - edges (Tensor): Edge prediction with size: number_edge * 2. - boxes (list): The text boxes corresponding to the nodes. - show (bool): Whether to show the image. Default: False. - win_name (str): The window name. Default: '' - wait_time (float): Value of waitKey param. Default: 0. - out_file (str or None): The filename to write the image. - Default: None. - - Returns: - np.ndarray: The image with key, value and relation drawn on it. - """ - img = mmcv.imread(img) - h, w = img.shape[:2] - color_list = gen_color() - - for i, box in enumerate(boxes): - new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], - [box[0], box[3]]] - Pts = np.array([new_box], np.int32) - cv2.polylines( - img, [Pts.reshape((-1, 1, 2))], - True, - color=color_list[i % len(color_list)], - thickness=1) - - pred_img_h = h - pred_img_w = w - - pred_edge_img = draw_edge_result(img, result) - pred_img_h = max(pred_img_h, pred_edge_img.shape[0]) - pred_img_w += pred_edge_img.shape[1] - - vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8) - vis_img[:h, :w] = img - vis_img[:, w:] = 255 - - height_t, width_t = pred_edge_img.shape[:2] - vis_img[:height_t, w:(w + width_t)] = pred_edge_img - - if show: - mmcv.imshow(vis_img, win_name, wait_time) - if out_file is not None: - mmcv.imwrite(vis_img, out_file) - res_dic = { - 'boxes': boxes, - 'nodes': result['nodes'].detach().cpu(), - 'edges': result['edges'].detach().cpu(), - 'metas': result['img_metas'][0] - } - mmengine.dump(res_dic, f'{out_file}_res.pkl') - - return vis_img diff --git a/tests/test_visualization/test_base_visualizer.py b/tests/test_visualization/test_base_visualizer.py new file mode 100644 index 00000000..57abc242 --- /dev/null +++ b/tests/test_visualization/test_base_visualizer.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase
+
+import numpy as np
+
+from mmocr.visualization import BaseLocalVisualizer
+
+
+class TestBaseLocalVisualizer(TestCase):
+
+    def test_get_labels_image(self):
+        labels = ['a', 'b', 'c']
+        image = np.zeros((40, 40, 3), dtype=np.uint8)
+        bboxes = np.array([[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]])
+        labels_image = BaseLocalVisualizer().get_labels_image(
+            image,
+            labels,
+            bboxes=bboxes,
+            auto_font_size=True,
+            colors=['r', 'r', 'r', 'r'])
+        self.assertEqual(labels_image.shape, (40, 40, 3))
+
+    def test_get_polygons_image(self):
+        polygons = [np.array([0, 0, 10, 10, 20, 20, 30, 30]).reshape(-1, 2)]
+        image = np.zeros((40, 40, 3), dtype=np.uint8)
+        polygons_image = BaseLocalVisualizer().get_polygons_image(
+            image, polygons, colors=['r', 'r', 'r', 'r'])
+        self.assertEqual(polygons_image.shape, (40, 40, 3))
+
+        polygons_image = BaseLocalVisualizer().get_polygons_image(
+            image, polygons, colors=['r', 'r', 'r', 'r'], filling=True)
+        self.assertEqual(polygons_image.shape, (40, 40, 3))
+
+    def test_get_bboxes_image(self):
+        bboxes = np.array([[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]])
+        image = np.zeros((40, 40, 3), dtype=np.uint8)
+        bboxes_image = BaseLocalVisualizer().get_bboxes_image(
+            image, bboxes, colors=['r', 'r', 'r', 'r'])
+        self.assertEqual(bboxes_image.shape, (40, 40, 3))
+
+        bboxes_image = BaseLocalVisualizer().get_bboxes_image(
+            image, bboxes, colors=['r', 'r', 'r', 'r'], filling=True)
+        self.assertEqual(bboxes_image.shape, (40, 40, 3))
+
+    def test_cat_images(self):
+        image1 = np.zeros((40, 40, 3), dtype=np.uint8)
+        image2 = np.zeros((40, 40, 3), dtype=np.uint8)
+        image = BaseLocalVisualizer()._cat_image([image1, image2], axis=1)
+        self.assertEqual(image.shape, (40, 80, 3))
+
+        image = BaseLocalVisualizer()._cat_image([], axis=0)
+        self.assertIsNone(image)
+
+        image = BaseLocalVisualizer()._cat_image([image1, None], axis=0)
+        self.assertEqual(image.shape, (40, 40, 3))
diff --git a/tests/test_visualization/test_kie_visualizer.py b/tests/test_visualization/test_kie_visualizer.py
index 5237d6b4..0cc650b3 100644
--- a/tests/test_visualization/test_kie_visualizer.py
+++ b/tests/test_visualization/test_kie_visualizer.py
@@ -105,6 +105,21 @@ class TestTextKIELocalVisualizer(unittest.TestCase):
             out_file=out_file)
         self._assert_image_and_shape(out_file, (h, w * 4, c))
 
+        visualizer = KIELocalVisualizer(is_openset=False)
+        visualizer.dataset_meta = dict(category=[
+            dict(id=0, name='bg'),
+            dict(id=1, name='key'),
+            dict(id=2, name='value'),
+            dict(id=3, name='other')
+        ])
+        visualizer.add_datasample(
+            'image',
+            image,
+            self.data_sample,
+            draw_pred=False,
+            out_file=out_file)
+        self._assert_image_and_shape(out_file, (h, w * 3, c))
+
     def _assert_image_and_shape(self, out_file, out_shape):
         self.assertTrue(osp.exists(out_file))
         drawn_img = cv2.imread(out_file)
diff --git a/tests/test_visualization/test_textdet_visualizer.py b/tests/test_visualization/test_textdet_visualizer.py
index c6da4901..21a493ad 100644
--- a/tests/test_visualization/test_textdet_visualizer.py
+++ b/tests/test_visualization/test_textdet_visualizer.py
@@ -101,6 +101,10 @@ class TestTextDetLocalVisualizer(unittest.TestCase):
                 out_file=out_file)
             self._assert_image_and_shape(out_file, (h, w, c))
 
+            det_local_visualizer.add_datasample(
+                'image', image, None, out_file=out_file)
+            self._assert_image_and_shape(out_file, (h, w, c))
+
     def _assert_image_and_shape(self, out_file, out_shape):
         self.assertTrue(osp.exists(out_file))
         drawn_img = cv2.imread(out_file)
diff --git a/tests/test_visualization/test_textrecog_visualizer.py b/tests/test_visualization/test_textrecog_visualizer.py
index 1154f770..3171a02d 100644
--- a/tests/test_visualization/test_textrecog_visualizer.py
+++ b/tests/test_visualization/test_textrecog_visualizer.py
@@ -46,7 +46,7 @@ class TestTextDetLocalVisualizer(unittest.TestCase):
             draw_pred=False)
         self._assert_image_and_shape(out_file, (h * 2, w, 3))
 
-        # draw_gt = True + gt_sample + pred_sample
+        # draw_gt = True
         recog_local_visualizer.add_datasample(
             'image',
             image,
@@ -56,7 +56,13 @@
             draw_pred=True)
         self._assert_image_and_shape(out_file, (h * 3, w, 3))
 
-        # draw_gt = False + gt_sample + pred_sample
+        # draw_gt = False
+        recog_local_visualizer.add_datasample(
+            'image', image, data_sample, draw_gt=False, out_file=out_file)
+        self._assert_image_and_shape(out_file, (h * 2, w, 3))
+
+        # gray image
+        image = np.random.randint(0, 256, size=(h, w)).astype('uint8')
         recog_local_visualizer.add_datasample(
             'image', image, data_sample, draw_gt=False, out_file=out_file)
         self._assert_image_and_shape(out_file, (h * 2, w, 3))
diff --git a/tests/test_visualization/test_textspotting_visualizer.py b/tests/test_visualization/test_textspotting_visualizer.py
new file mode 100644
index 00000000..91086475
--- /dev/null
+++ b/tests/test_visualization/test_textspotting_visualizer.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+import unittest
+
+import cv2
+import numpy as np
+import torch
+from mmengine.structures import InstanceData
+
+from mmocr.structures import TextDetDataSample
+from mmocr.utils import bbox2poly
+from mmocr.visualization import TextSpottingLocalVisualizer
+
+
+class TestTextSpottingLocalVisualizer(unittest.TestCase):
+
+    def setUp(self):
+        h, w = 12, 10
+        self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
+        # gt_instances
+        data_sample = TextDetDataSample()
+        gt_instances_data = dict(
+            bboxes=self._rand_bboxes(5, h, w),
+            polygons=self._rand_polys(5, h, w),
+            labels=torch.zeros(5, ),
+            texts=['text1', 'text2', 'text3', 'text4', 'text5'])
+        gt_instances = InstanceData(**gt_instances_data)
+        data_sample.gt_instances = gt_instances
+
+        pred_instances_data = dict(
+            bboxes=self._rand_bboxes(5, h, w),
+            labels=torch.zeros(5, ),
+            scores=torch.rand((5, )),
+            texts=['text1', 'text2', 'text3', 'text4', 'text5'])
+        pred_instances = InstanceData(**pred_instances_data)
+        data_sample.pred_instances = pred_instances
+        data_sample = data_sample.numpy()
+        self.data_sample = data_sample
+
+    @staticmethod
+    def _rand_bboxes(num_boxes, h, w):
+        cx, cy, bw, bh = torch.rand(num_boxes, 4).T
+
+        tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0)
+        tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0)
+        br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0)
+        br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0)
+
+        bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T
+
+        return bboxes
+
+    def _rand_polys(self, num_bboxes, h, w):
+        bboxes = self._rand_bboxes(num_bboxes, h, w)
+        bboxes = bboxes.tolist()
+        polys = [bbox2poly(bbox) for bbox in bboxes]
+        return polys
+
+    def test_add_datasample(self):
+        image = self.image
+        h, w, c = image.shape
+
+        visualizer = TextSpottingLocalVisualizer()
+        visualizer.add_datasample('image', image, self.data_sample)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # test out
+            out_file = osp.join(tmp_dir, 'out_file.jpg')
+            visualizer.add_datasample(
+                'image',
+                image,
+                self.data_sample,
+                out_file=out_file,
+                draw_gt=False,
+                draw_pred=False)
+            self._assert_image_and_shape(out_file, (h, w, c))
+
+            visualizer.add_datasample(
+                'image', image, self.data_sample, out_file=out_file)
+            self._assert_image_and_shape(out_file, (h * 2, w * 2, c))
+
+            visualizer.add_datasample(
+                'image',
+                image,
+                self.data_sample,
+                draw_gt=False,
+                out_file=out_file)
+            self._assert_image_and_shape(out_file, (h, w * 2, c))
+
+            visualizer.add_datasample(
+                'image',
+                image,
+                self.data_sample,
+                draw_pred=False,
+                out_file=out_file)
+            self._assert_image_and_shape(out_file, (h, w * 2, c))
+            bboxes = self.data_sample.pred_instances.pop('bboxes')
+            bboxes = bboxes.tolist()
+            polys = [bbox2poly(bbox) for bbox in bboxes]
+            self.data_sample.pred_instances.polygons = polys
+            visualizer.add_datasample(
+                'image',
+                image,
+                self.data_sample,
+                draw_gt=False,
+                out_file=out_file)
+            self._assert_image_and_shape(out_file, (h, w * 2, c))
+
+    def _assert_image_and_shape(self, out_file, out_shape):
+        self.assertTrue(osp.exists(out_file))
+        drawn_img = cv2.imread(out_file)
+        self.assertEqual(drawn_img.shape, out_shape)
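
Reviewer note: a minimal usage sketch of the refactored helper API that the new tests exercise. The method names (get_bboxes_image, get_labels_image, _cat_image) come from this patch; the canvas size, boxes, labels, and colors are illustrative assumptions, not part of the change.

import numpy as np

from mmocr.visualization import BaseLocalVisualizer

visualizer = BaseLocalVisualizer()
canvas = np.full((40, 40, 3), 255, dtype=np.uint8)  # blank RGB canvas (illustrative)
bboxes = np.array([[2, 2, 18, 18], [22, 22, 38, 38]])
labels = ['key', 'value']

# Each get_*_image helper sets the image on the visualizer itself and
# returns the drawn result, so calls can be chained instead of threading
# a separate `visualizer` argument as the old static _draw_* helpers did.
drawn = visualizer.get_bboxes_image(canvas, bboxes, colors='g')
drawn = visualizer.get_labels_image(drawn, labels, bboxes=bboxes)

# _cat_image concatenates panels along the given axis, skips None entries,
# and returns None for an empty list (see test_cat_images above).
panel = visualizer._cat_image([canvas, drawn], axis=1)
assert panel.shape == (40, 80, 3)

This return-the-image style is what lets add_datasample compose ground-truth and prediction panels and concatenate them in one pass, which the shape assertions in the tests above (h, w * 2, c) and (h * 2, w * 2, c) check.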