mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] TextDetLocalVisualizer
parent
c78be99f6b
commit
ee48713a89
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .textdet_visualizer import TextDetLocalVisualizer
|
||||
from .textrecog_visualizer import TextRecogLocalVisualizer
|
||||
|
||||
__all__ = ['TextRecogLocalVisualizer']
|
||||
__all__ = ['TextDetLocalVisualizer', 'TextRecogLocalVisualizer']
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine import Visualizer
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.registry import VISUALIZERS
|
||||
|
||||
|
||||
@VISUALIZERS.register_module()
|
||||
class TextDetLocalVisualizer(Visualizer):
|
||||
"""The MMOCR Text Detection Local Visualizer.
|
||||
|
||||
Args:
|
||||
name (str): Name of the instance. Defaults to 'visualizer'.
|
||||
image (np.ndarray, optional): The origin image to draw. The format
|
||||
should be RGB. Defaults to None.
|
||||
with_poly (bool): Whether to draw polygons. Defaults to True.
|
||||
with_bbox (bool): Whether to draw bboxes. Defaults to False.
|
||||
vis_backends (list, optional): Visual backend config list.
|
||||
Defaults to None.
|
||||
save_dir (str, optional): Save file dir for all storage backends.
|
||||
If it is None, the backend storage will not save any data.
|
||||
gt_color (Union[str, tuple, list[str], list[tuple]]): The
|
||||
colors of GT polygons and bboxes. ``colors`` can have the same
|
||||
length with lines or just single value. If ``colors`` is single
|
||||
value, all the lines will have the same colors. Refer to
|
||||
`matplotlib.colors` for full list of formats that are accepted.
|
||||
Defaults to 'g'.
|
||||
pred_color (Union[str, tuple, list[str], list[tuple]]): The
|
||||
colors of pred polygons and bboxes. ``colors`` can have the same
|
||||
length with lines or just single value. If ``colors`` is single
|
||||
value, all the lines will have the same colors. Refer to
|
||||
`matplotlib.colors` for full list of formats that are accepted.
|
||||
Defaults to 'r'.
|
||||
line_width (int, float): The linewidth of lines. Defaults to 2.
|
||||
alpha (float): The transparency of bboxes or polygons. Defaults to 0.8.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str = 'visualizer',
|
||||
image: Optional[np.ndarray] = None,
|
||||
with_poly: bool = True,
|
||||
with_bbox: bool = False,
|
||||
vis_backends: Optional[Dict] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
gt_color: Union[str, Tuple, List[str], List[Tuple]] = 'g',
|
||||
pred_color: Union[str, Tuple, List[str], List[Tuple]] = 'r',
|
||||
line_width: Union[int, float] = 2,
|
||||
alpha: float = 0.8) -> None:
|
||||
super().__init__(
|
||||
name=name,
|
||||
image=image,
|
||||
vis_backends=vis_backends,
|
||||
save_dir=save_dir)
|
||||
self.with_poly = with_poly
|
||||
self.with_bbox = with_bbox
|
||||
self.gt_color = gt_color
|
||||
self.pred_color = pred_color
|
||||
self.line_width = line_width
|
||||
self.alpha = alpha
|
||||
|
||||
def add_datasample(self,
|
||||
name: str,
|
||||
image: np.ndarray,
|
||||
gt_sample: Optional['TextDetDataSample'] = None,
|
||||
pred_sample: Optional['TextDetDataSample'] = None,
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
out_file: Optional[str] = None,
|
||||
pred_score_thr: float = 0.3,
|
||||
step: int = 0) -> None:
|
||||
"""Draw datasample and save to all backends.
|
||||
|
||||
- If GT and prediction are plotted at the same time, they are
|
||||
displayed in a stitched image where the left image is the
|
||||
ground truth and the right image is the prediction.
|
||||
- If ``show`` is True, all storage backends are ignored, and
|
||||
the images will be displayed in a local window.
|
||||
- If ``out_file`` is specified, the drawn image will be
|
||||
saved to ``out_file``. This is usually used when the display
|
||||
is not available.
|
||||
|
||||
Args:
|
||||
name (str): The image identifier.
|
||||
image (np.ndarray): The image to draw.
|
||||
gt_sample (:obj:`TextDetDataSample`, optional): GT
|
||||
TextDetDataSample. Defaults to None.
|
||||
pred_sample (:obj:`TextDetDataSample`, optional): Predicted
|
||||
TextDetDataSample. Defaults to None.
|
||||
draw_gt (bool): Whether to draw GT TextDetDataSample.
|
||||
Defaults to True.
|
||||
draw_pred (bool): Whether to draw Predicted TextDetDataSample.
|
||||
Defaults to True.
|
||||
show (bool): Whether to display the drawn image. Default to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
out_file (str): Path to output file. Defaults to None.
|
||||
pred_score_thr (float): The threshold to visualize the bboxes
|
||||
and masks. Defaults to 0.3.
|
||||
step (int): Global step value to record. Defaults to 0.
|
||||
"""
|
||||
gt_img_data = None
|
||||
pred_img_data = None
|
||||
|
||||
if draw_gt and gt_sample is not None and 'gt_instances' in gt_sample:
|
||||
gt_instances = gt_sample.gt_instances
|
||||
|
||||
self.set_image(image)
|
||||
|
||||
if self.with_poly and 'polygons' in gt_instances:
|
||||
gt_polygons = gt_instances.polygons
|
||||
gt_polygons = [
|
||||
gt_polygon.reshape(-1, 2) for gt_polygon in gt_polygons
|
||||
]
|
||||
self.draw_polygons(
|
||||
gt_polygons,
|
||||
alpha=self.alpha,
|
||||
edge_colors=self.gt_color,
|
||||
line_widths=self.line_width)
|
||||
|
||||
if self.with_bbox and 'bboxes' in gt_instances:
|
||||
gt_bboxes = gt_instances.bboxes
|
||||
self.draw_bboxes(
|
||||
gt_bboxes,
|
||||
alpha=self.alpha,
|
||||
edge_colors=self.gt_color,
|
||||
line_widths=self.line_width)
|
||||
|
||||
gt_img_data = self.get_image()
|
||||
|
||||
if draw_pred and pred_sample is not None \
|
||||
and 'pred_instances' in pred_sample:
|
||||
pred_instances = pred_sample.pred_instances
|
||||
pred_instances = pred_instances[
|
||||
pred_instances.scores > pred_score_thr].cpu()
|
||||
|
||||
self.set_image(image)
|
||||
|
||||
if self.with_poly and 'polygons' in pred_instances:
|
||||
pred_polygons = pred_instances.polygons
|
||||
pred_polygons = [
|
||||
pred_polygon.reshape(-1, 2)
|
||||
for pred_polygon in pred_polygons
|
||||
]
|
||||
self.draw_polygons(
|
||||
pred_polygons,
|
||||
alpha=self.alpha,
|
||||
edge_colors=self.pred_color,
|
||||
line_widths=self.line_width)
|
||||
|
||||
if self.with_bbox and 'bboxes' in pred_instances:
|
||||
pred_bboxes = pred_instances.bboxes
|
||||
self.draw_bboxes(
|
||||
pred_bboxes,
|
||||
alpha=self.alpha,
|
||||
edge_colors=self.pred_color,
|
||||
line_widths=self.line_width)
|
||||
|
||||
pred_img_data = self.get_image()
|
||||
|
||||
if gt_img_data is not None and pred_img_data is not None:
|
||||
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
|
||||
elif gt_img_data is not None:
|
||||
drawn_img = gt_img_data
|
||||
else:
|
||||
drawn_img = pred_img_data
|
||||
|
||||
if show:
|
||||
self.show(drawn_img, win_name=name, wait_time=wait_time)
|
||||
else:
|
||||
self.add_image(name, drawn_img, step)
|
||||
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(drawn_img[..., ::-1], out_file)
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmocr.core import TextDetDataSample
|
||||
from mmocr.core.visualization import TextDetLocalVisualizer
|
||||
from mmocr.utils import bbox2poly
|
||||
|
||||
|
||||
class TestTextDetLocalVisualizer(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
h, w = 12, 10
|
||||
self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
|
||||
|
||||
# gt_instances
|
||||
gt_det_data_sample = TextDetDataSample()
|
||||
gt_instances_data = dict(
|
||||
bboxes=self._rand_bboxes(5, h, w),
|
||||
polygons=self._rand_polys(5, h, w),
|
||||
labels=torch.zeros(5, ))
|
||||
gt_instances = InstanceData(**gt_instances_data)
|
||||
gt_det_data_sample.gt_instances = gt_instances
|
||||
self.gt_det_data_sample = gt_det_data_sample
|
||||
|
||||
# pred_instances
|
||||
pred_det_data_sample = TextDetDataSample()
|
||||
pred_instances_data = dict(
|
||||
bboxes=self._rand_bboxes(5, h, w),
|
||||
polygons=self._rand_polys(5, h, w),
|
||||
labels=torch.zeros(5, ),
|
||||
scores=torch.rand((5, )))
|
||||
pred_instances = InstanceData(**pred_instances_data)
|
||||
pred_det_data_sample.pred_instances = pred_instances
|
||||
self.pred_det_data_sample = pred_det_data_sample
|
||||
|
||||
def test_text_det_local_visualizer(self):
|
||||
for with_poly in [True, False]:
|
||||
for with_bbox in [True, False]:
|
||||
vis_cfg = dict(with_poly=with_poly, with_bbox=with_bbox)
|
||||
self._test_add_datasample(vis_cfg=vis_cfg)
|
||||
|
||||
@staticmethod
|
||||
def _rand_bboxes(num_boxes, h, w):
|
||||
cx, cy, bw, bh = torch.rand(num_boxes, 4).T
|
||||
|
||||
tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
|
||||
tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
|
||||
br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
|
||||
br_y = ((cy * h) + (h * bh / 2)).clip(0, h)
|
||||
|
||||
bboxes = torch.vstack([tl_x, tl_y, br_x, br_y]).T
|
||||
|
||||
return bboxes
|
||||
|
||||
def _rand_polys(self, num_bboxes, h, w):
|
||||
bboxes = self._rand_bboxes(num_bboxes, h, w)
|
||||
bboxes = bboxes.tolist()
|
||||
polys = [bbox2poly(bbox) for bbox in bboxes]
|
||||
return polys
|
||||
|
||||
def _test_add_datasample(self, vis_cfg):
|
||||
image = self.image
|
||||
h, w, c = image.shape
|
||||
gt_det_data_sample = self.gt_det_data_sample
|
||||
pred_det_data_sample = self.pred_det_data_sample
|
||||
|
||||
det_local_visualizer = TextDetLocalVisualizer(**vis_cfg)
|
||||
det_local_visualizer.add_datasample('image', image, gt_det_data_sample)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# test out
|
||||
out_file = osp.join(tmp_dir, 'out_file.jpg')
|
||||
det_local_visualizer.add_datasample(
|
||||
'image', image, gt_det_data_sample, out_file=out_file)
|
||||
self._assert_image_and_shape(out_file, (h, w, c))
|
||||
|
||||
det_local_visualizer.add_datasample(
|
||||
'image',
|
||||
image,
|
||||
gt_det_data_sample,
|
||||
pred_det_data_sample,
|
||||
out_file=out_file)
|
||||
self._assert_image_and_shape(out_file, (h, w * 2, c))
|
||||
|
||||
det_local_visualizer.add_datasample(
|
||||
'image',
|
||||
image,
|
||||
gt_det_data_sample,
|
||||
pred_det_data_sample,
|
||||
draw_gt=False,
|
||||
out_file=out_file)
|
||||
self._assert_image_and_shape(out_file, (h, w, c))
|
||||
|
||||
det_local_visualizer.add_datasample(
|
||||
'image',
|
||||
image,
|
||||
gt_det_data_sample,
|
||||
pred_det_data_sample,
|
||||
draw_pred=False,
|
||||
out_file=out_file)
|
||||
self._assert_image_and_shape(out_file, (h, w, c))
|
||||
|
||||
def _assert_image_and_shape(self, out_file, out_shape):
|
||||
self.assertTrue(osp.exists(out_file))
|
||||
drawn_img = cv2.imread(out_file)
|
||||
self.assertTrue(drawn_img.shape == out_shape)
|
Loading…
Reference in New Issue