[Refactor] TextDetLocalVisualizer

pull/1178/head
wangxinyu 2022-05-26 05:01:01 +00:00 committed by gaotongxiao
parent c78be99f6b
commit ee48713a89
3 changed files with 293 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .textdet_visualizer import TextDetLocalVisualizer
from .textrecog_visualizer import TextRecogLocalVisualizer
__all__ = ['TextRecogLocalVisualizer']
__all__ = ['TextDetLocalVisualizer', 'TextRecogLocalVisualizer']

View File

@ -0,0 +1,178 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import mmcv
import numpy as np
from mmengine import Visualizer
from mmocr.core import TextDetDataSample
from mmocr.registry import VISUALIZERS
@VISUALIZERS.register_module()
class TextDetLocalVisualizer(Visualizer):
"""The MMOCR Text Detection Local Visualizer.
Args:
name (str): Name of the instance. Defaults to 'visualizer'.
image (np.ndarray, optional): The origin image to draw. The format
should be RGB. Defaults to None.
with_poly (bool): Whether to draw polygons. Defaults to True.
with_bbox (bool): Whether to draw bboxes. Defaults to False.
vis_backends (list, optional): Visual backend config list.
Defaults to None.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
gt_color (Union[str, tuple, list[str], list[tuple]]): The
colors of GT polygons and bboxes. ``colors`` can have the same
length with lines or just single value. If ``colors`` is single
value, all the lines will have the same colors. Refer to
`matplotlib.colors` for full list of formats that are accepted.
Defaults to 'g'.
pred_color (Union[str, tuple, list[str], list[tuple]]): The
colors of pred polygons and bboxes. ``colors`` can have the same
length with lines or just single value. If ``colors`` is single
value, all the lines will have the same colors. Refer to
`matplotlib.colors` for full list of formats that are accepted.
Defaults to 'r'.
line_width (int, float): The linewidth of lines. Defaults to 2.
alpha (float): The transparency of bboxes or polygons. Defaults to 0.8.
"""
def __init__(self,
name: str = 'visualizer',
image: Optional[np.ndarray] = None,
with_poly: bool = True,
with_bbox: bool = False,
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
gt_color: Union[str, Tuple, List[str], List[Tuple]] = 'g',
pred_color: Union[str, Tuple, List[str], List[Tuple]] = 'r',
line_width: Union[int, float] = 2,
alpha: float = 0.8) -> None:
super().__init__(
name=name,
image=image,
vis_backends=vis_backends,
save_dir=save_dir)
self.with_poly = with_poly
self.with_bbox = with_bbox
self.gt_color = gt_color
self.pred_color = pred_color
self.line_width = line_width
self.alpha = alpha
def add_datasample(self,
name: str,
image: np.ndarray,
gt_sample: Optional['TextDetDataSample'] = None,
pred_sample: Optional['TextDetDataSample'] = None,
draw_gt: bool = True,
draw_pred: bool = True,
show: bool = False,
wait_time: int = 0,
out_file: Optional[str] = None,
pred_score_thr: float = 0.3,
step: int = 0) -> None:
"""Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are
displayed in a stitched image where the left image is the
ground truth and the right image is the prediction.
- If ``show`` is True, all storage backends are ignored, and
the images will be displayed in a local window.
- If ``out_file`` is specified, the drawn image will be
saved to ``out_file``. This is usually used when the display
is not available.
Args:
name (str): The image identifier.
image (np.ndarray): The image to draw.
gt_sample (:obj:`TextDetDataSample`, optional): GT
TextDetDataSample. Defaults to None.
pred_sample (:obj:`TextDetDataSample`, optional): Predicted
TextDetDataSample. Defaults to None.
draw_gt (bool): Whether to draw GT TextDetDataSample.
Defaults to True.
draw_pred (bool): Whether to draw Predicted TextDetDataSample.
Defaults to True.
show (bool): Whether to display the drawn image. Default to False.
wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None.
pred_score_thr (float): The threshold to visualize the bboxes
and masks. Defaults to 0.3.
step (int): Global step value to record. Defaults to 0.
"""
gt_img_data = None
pred_img_data = None
if draw_gt and gt_sample is not None and 'gt_instances' in gt_sample:
gt_instances = gt_sample.gt_instances
self.set_image(image)
if self.with_poly and 'polygons' in gt_instances:
gt_polygons = gt_instances.polygons
gt_polygons = [
gt_polygon.reshape(-1, 2) for gt_polygon in gt_polygons
]
self.draw_polygons(
gt_polygons,
alpha=self.alpha,
edge_colors=self.gt_color,
line_widths=self.line_width)
if self.with_bbox and 'bboxes' in gt_instances:
gt_bboxes = gt_instances.bboxes
self.draw_bboxes(
gt_bboxes,
alpha=self.alpha,
edge_colors=self.gt_color,
line_widths=self.line_width)
gt_img_data = self.get_image()
if draw_pred and pred_sample is not None \
and 'pred_instances' in pred_sample:
pred_instances = pred_sample.pred_instances
pred_instances = pred_instances[
pred_instances.scores > pred_score_thr].cpu()
self.set_image(image)
if self.with_poly and 'polygons' in pred_instances:
pred_polygons = pred_instances.polygons
pred_polygons = [
pred_polygon.reshape(-1, 2)
for pred_polygon in pred_polygons
]
self.draw_polygons(
pred_polygons,
alpha=self.alpha,
edge_colors=self.pred_color,
line_widths=self.line_width)
if self.with_bbox and 'bboxes' in pred_instances:
pred_bboxes = pred_instances.bboxes
self.draw_bboxes(
pred_bboxes,
alpha=self.alpha,
edge_colors=self.pred_color,
line_widths=self.line_width)
pred_img_data = self.get_image()
if gt_img_data is not None and pred_img_data is not None:
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
elif gt_img_data is not None:
drawn_img = gt_img_data
else:
drawn_img = pred_img_data
if show:
self.show(drawn_img, win_name=name, wait_time=wait_time)
else:
self.add_image(name, drawn_img, step)
if out_file is not None:
mmcv.imwrite(drawn_img[..., ::-1], out_file)

View File

@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest
import cv2
import numpy as np
import torch
from mmengine.data import InstanceData
from mmocr.core import TextDetDataSample
from mmocr.core.visualization import TextDetLocalVisualizer
from mmocr.utils import bbox2poly
class TestTextDetLocalVisualizer(unittest.TestCase):
def setUp(self):
h, w = 12, 10
self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
# gt_instances
gt_det_data_sample = TextDetDataSample()
gt_instances_data = dict(
bboxes=self._rand_bboxes(5, h, w),
polygons=self._rand_polys(5, h, w),
labels=torch.zeros(5, ))
gt_instances = InstanceData(**gt_instances_data)
gt_det_data_sample.gt_instances = gt_instances
self.gt_det_data_sample = gt_det_data_sample
# pred_instances
pred_det_data_sample = TextDetDataSample()
pred_instances_data = dict(
bboxes=self._rand_bboxes(5, h, w),
polygons=self._rand_polys(5, h, w),
labels=torch.zeros(5, ),
scores=torch.rand((5, )))
pred_instances = InstanceData(**pred_instances_data)
pred_det_data_sample.pred_instances = pred_instances
self.pred_det_data_sample = pred_det_data_sample
def test_text_det_local_visualizer(self):
for with_poly in [True, False]:
for with_bbox in [True, False]:
vis_cfg = dict(with_poly=with_poly, with_bbox=with_bbox)
self._test_add_datasample(vis_cfg=vis_cfg)
@staticmethod
def _rand_bboxes(num_boxes, h, w):
cx, cy, bw, bh = torch.rand(num_boxes, 4).T
tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
br_y = ((cy * h) + (h * bh / 2)).clip(0, h)
bboxes = torch.vstack([tl_x, tl_y, br_x, br_y]).T
return bboxes
def _rand_polys(self, num_bboxes, h, w):
bboxes = self._rand_bboxes(num_bboxes, h, w)
bboxes = bboxes.tolist()
polys = [bbox2poly(bbox) for bbox in bboxes]
return polys
def _test_add_datasample(self, vis_cfg):
image = self.image
h, w, c = image.shape
gt_det_data_sample = self.gt_det_data_sample
pred_det_data_sample = self.pred_det_data_sample
det_local_visualizer = TextDetLocalVisualizer(**vis_cfg)
det_local_visualizer.add_datasample('image', image, gt_det_data_sample)
with tempfile.TemporaryDirectory() as tmp_dir:
# test out
out_file = osp.join(tmp_dir, 'out_file.jpg')
det_local_visualizer.add_datasample(
'image', image, gt_det_data_sample, out_file=out_file)
self._assert_image_and_shape(out_file, (h, w, c))
det_local_visualizer.add_datasample(
'image',
image,
gt_det_data_sample,
pred_det_data_sample,
out_file=out_file)
self._assert_image_and_shape(out_file, (h, w * 2, c))
det_local_visualizer.add_datasample(
'image',
image,
gt_det_data_sample,
pred_det_data_sample,
draw_gt=False,
out_file=out_file)
self._assert_image_and_shape(out_file, (h, w, c))
det_local_visualizer.add_datasample(
'image',
image,
gt_det_data_sample,
pred_det_data_sample,
draw_pred=False,
out_file=out_file)
self._assert_image_and_shape(out_file, (h, w, c))
def _assert_image_and_shape(self, out_file, out_shape):
self.assertTrue(osp.exists(out_file))
drawn_img = cv2.imread(out_file)
self.assertTrue(drawn_img.shape == out_shape)