mirror of https://github.com/open-mmlab/mmocr.git
parent b26907e908
commit dfc17207ba
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .base_visualizer import BaseLocalVisualizer
 from .kie_visualizer import KIELocalVisualizer
 from .textdet_visualizer import TextDetLocalVisualizer
 from .textrecog_visualizer import TextRecogLocalVisualizer
 from .textspotting_visualizer import TextSpottingLocalVisualizer

 __all__ = [
-    'KIELocalVisualizer', 'TextDetLocalVisualizer', 'TextRecogLocalVisualizer',
-    'TextSpottingLocalVisualizer'
+    'BaseLocalVisualizer', 'KIELocalVisualizer', 'TextDetLocalVisualizer',
+    'TextRecogLocalVisualizer', 'TextSpottingLocalVisualizer'
 ]
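Note: with this hunk the package also re-exports the base class. A minimal
usage sketch, assuming an mmocr install that includes this commit (the
subclass name below is made up):

import numpy as np
from mmocr.visualization import BaseLocalVisualizer

class DemoVisualizer(BaseLocalVisualizer):
    """Toy subclass reusing the shared drawing helpers."""

    def _draw_instances(self, image: np.ndarray) -> np.ndarray:
        return image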
@@ -50,14 +50,13 @@ class BaseLocalVisualizer(Visualizer):
                (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
                (191, 162, 208)]

-    @staticmethod
-    def _draw_labels(visualizer: Visualizer,
-                     image: np.ndarray,
-                     labels: Union[np.ndarray, torch.Tensor],
-                     bboxes: Union[np.ndarray, torch.Tensor],
-                     colors: Union[str, Sequence[str]] = 'k',
-                     font_size: Union[int, float] = 10,
-                     auto_font_size: bool = False) -> np.ndarray:
+    def get_labels_image(self,
+                         image: np.ndarray,
+                         labels: Union[np.ndarray, torch.Tensor],
+                         bboxes: Union[np.ndarray, torch.Tensor],
+                         colors: Union[str, Sequence[str]] = 'k',
+                         font_size: Union[int, float] = 10,
+                         auto_font_size: bool = False) -> np.ndarray:
         """Draw labels on image.

         Args:
@@ -75,7 +74,7 @@ class BaseLocalVisualizer(Visualizer):
             auto_font_size (bool): Whether to automatically adjust font size.
                 Defaults to False.
         """
-        if colors is not None and isinstance(colors, Sequence):
+        if colors is not None and isinstance(colors, (list, tuple)):
             size = math.ceil(len(labels) / len(colors))
             colors = (colors * size)[:len(labels)]
         if auto_font_size:
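Note: the isinstance check is narrowed from Sequence to (list, tuple)
because a plain string such as 'k' is itself a Sequence and would wrongly
enter the cycling branch. The cycling repeats a short color list until it
covers every label; a standalone sketch of that logic:

import math

labels = ['a', 'b', 'c', 'd', 'e']
colors = ['r', 'g']
size = math.ceil(len(labels) / len(colors))  # 3
colors = (colors * size)[:len(labels)]       # ['r', 'g', 'r', 'g', 'r']
assert len(colors) == len(labels)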
@@ -83,68 +82,124 @@ class BaseLocalVisualizer(Visualizer):
                 font_size, (int, float))
             font_size = (bboxes[:, 2:] - bboxes[:, :2]).min(-1) * font_size
             font_size = font_size.tolist()
-        visualizer.set_image(image)
-        visualizer.draw_texts(
+        self.set_image(image)
+        self.draw_texts(
             labels, (bboxes[:, :2] + bboxes[:, 2:]) / 2,
             vertical_alignments='center',
             horizontal_alignments='center',
             colors='k',
             font_sizes=font_size)
-        return visualizer.get_image()
+        return self.get_image()

-    @staticmethod
-    def _draw_polygons(visualizer: Visualizer,
-                       image: np.ndarray,
-                       polygons: Sequence[np.ndarray],
-                       colors: Union[str, Sequence[str]] = 'g',
-                       filling: bool = False,
-                       line_width: Union[int, float] = 0.5,
-                       alpha: float = 0.5) -> np.ndarray:
-        if colors is not None and isinstance(colors, Sequence):
+    def get_polygons_image(self,
+                           image: np.ndarray,
+                           polygons: Sequence[np.ndarray],
+                           colors: Union[str, Sequence[str]] = 'g',
+                           filling: bool = False,
+                           line_width: Union[int, float] = 0.5,
+                           alpha: float = 0.5) -> np.ndarray:
+        """Draw polygons on image.
+
+        Args:
+            image (np.ndarray): The origin image to draw. The format
+                should be RGB.
+            polygons (Sequence[np.ndarray]): The polygons to draw. The shape
+                should be (N, 2).
+            colors (Union[str, Sequence[str]]): The colors of polygons.
+                ``colors`` can have the same length with polygons or just
+                single value. If ``colors`` is single value, all the polygons
+                will have the same colors. Refer to `matplotlib.colors` for
+                full list of formats that are accepted. Defaults to 'g'.
+            filling (bool): Whether to fill the polygons. Defaults to False.
+            line_width (Union[int, float]): The line width of polygons.
+                Defaults to 0.5.
+            alpha (float): The alpha of polygons. Defaults to 0.5.
+
+        Returns:
+            np.ndarray: The image with polygons drawn.
+        """
+        if colors is not None and isinstance(colors, (list, tuple)):
             size = math.ceil(len(polygons) / len(colors))
             colors = (colors * size)[:len(polygons)]
-        visualizer.set_image(image)
+        self.set_image(image)
         if filling:
-            visualizer.draw_polygons(
+            self.draw_polygons(
                 polygons,
                 face_colors=colors,
                 edge_colors=colors,
                 line_widths=line_width,
                 alpha=alpha)
         else:
-            visualizer.draw_polygons(
+            self.draw_polygons(
                 polygons,
                 edge_colors=colors,
                 line_widths=line_width,
                 alpha=alpha)
-        return visualizer.get_image()
+        return self.get_image()

-    @staticmethod
-    def _draw_bboxes(visualizer: Visualizer,
-                     image: np.ndarray,
-                     bboxes: Union[np.ndarray, torch.Tensor],
-                     colors: Union[str, Sequence[str]] = 'g',
-                     filling: bool = False,
-                     line_width: Union[int, float] = 0.5,
-                     alpha: float = 0.5) -> np.ndarray:
-        if colors is not None and isinstance(colors, Sequence):
+    def get_bboxes_image(self: Visualizer,
+                         image: np.ndarray,
+                         bboxes: Union[np.ndarray, torch.Tensor],
+                         colors: Union[str, Sequence[str]] = 'g',
+                         filling: bool = False,
+                         line_width: Union[int, float] = 0.5,
+                         alpha: float = 0.5) -> np.ndarray:
+        """Draw bboxes on image.
+
+        Args:
+            image (np.ndarray): The origin image to draw. The format
+                should be RGB.
+            bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw.
+            colors (Union[str, Sequence[str]]): The colors of bboxes.
+                ``colors`` can have the same length with bboxes or just single
+                value. If ``colors`` is single value, all the bboxes will have
+                the same colors. Refer to `matplotlib.colors` for full list of
+                formats that are accepted. Defaults to 'g'.
+            filling (bool): Whether to fill the bboxes. Defaults to False.
+            line_width (Union[int, float]): The line width of bboxes.
+                Defaults to 0.5.
+            alpha (float): The alpha of bboxes. Defaults to 0.5.
+
+        Returns:
+            np.ndarray: The image with bboxes drawn.
+        """
+        if colors is not None and isinstance(colors, (list, tuple)):
             size = math.ceil(len(bboxes) / len(colors))
             colors = (colors * size)[:len(bboxes)]
-        visualizer.set_image(image)
+        self.set_image(image)
         if filling:
-            visualizer.draw_bboxes(
+            self.draw_bboxes(
                 bboxes,
                 face_colors=colors,
                 edge_colors=colors,
                 line_widths=line_width,
                 alpha=alpha)
         else:
-            visualizer.draw_bboxes(
+            self.draw_bboxes(
                 bboxes,
                 edge_colors=colors,
                 line_widths=line_width,
                 alpha=alpha)
-        return visualizer.get_image()
+        return self.get_image()
+
+    def _draw_instances(self) -> np.ndarray:
+        raise NotImplementedError
+
+    def _cat_image(self, imgs: Sequence[np.ndarray], axis: int) -> np.ndarray:
+        """Concatenate images.
+
+        Args:
+            imgs (Sequence[np.ndarray]): The images to concatenate.
+            axis (int): The axis to concatenate.
+
+        Returns:
+            np.ndarray: The concatenated image.
+        """
+        cat_image = list()
+        for img in imgs:
+            if img is not None:
+                cat_image.append(img)
+        if len(cat_image):
+            return np.concatenate(cat_image, axis=axis)
+        else:
+            return None
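Note: after this hunk the helpers are plain instance methods, so callers no
longer thread a Visualizer through a static call. A minimal sketch of the
new calling pattern (the toy image and boxes are invented for illustration):

import numpy as np
from mmocr.visualization import BaseLocalVisualizer

vis = BaseLocalVisualizer(name='demo')
image = np.full((64, 128, 3), 255, dtype=np.uint8)
bboxes = np.array([[10, 10, 60, 40], [70, 10, 120, 40]])

boxed = vis.get_bboxes_image(image, bboxes, colors='g')
labeled = vis.get_labels_image(boxed, ['a', 'b'], bboxes)
panel = vis._cat_image([boxed, labeled], axis=1)  # None panels are skipped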
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import math
 import warnings
 from typing import Dict, List, Optional, Sequence, Union

@@ -15,31 +14,11 @@ from mmengine.visualization.utils import (check_type, check_type_and_length,

 from mmocr.registry import VISUALIZERS
 from mmocr.structures import KIEDataSample
-
-PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
-           (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192),
-           (250, 170, 30), (100, 170, 30), (220, 220, 0), (175, 116, 175),
-           (250, 0, 30), (165, 42, 42), (255, 77, 255), (0, 226, 252),
-           (182, 182, 255), (0, 82, 0), (120, 166, 157), (110, 76, 0),
-           (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
-           (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
-           (255, 99, 164), (92, 0, 73), (133, 129, 255), (78, 180, 255),
-           (0, 228, 0), (174, 255, 243), (45, 89, 255), (134, 134, 103),
-           (145, 148, 174), (255, 208, 186), (197, 226, 255), (171, 134, 1),
-           (109, 63, 54), (207, 138, 255), (151, 0, 95), (9, 80, 61),
-           (84, 105, 51), (74, 65, 105), (166, 196, 102), (208, 195, 210),
-           (255, 109, 65), (0, 143, 149), (179, 0, 194), (209, 99, 106),
-           (5, 121, 0), (227, 255, 205), (147, 186, 208), (153, 69, 1),
-           (3, 95, 161), (163, 255, 0), (119, 0, 170), (0, 182, 199),
-           (0, 165, 120), (183, 130, 88), (95, 32, 0), (130, 114, 135),
-           (110, 129, 133), (166, 74, 118), (219, 142, 185), (79, 210, 114),
-           (178, 90, 62), (65, 70, 15), (127, 167, 115), (59, 105, 106),
-           (142, 108, 45), (196, 172, 0), (95, 54, 80), (128, 76, 255),
-           (201, 57, 1), (246, 0, 122), (191, 162, 208)]
+from .base_visualizer import BaseLocalVisualizer


 @VISUALIZERS.register_module()
-class KIELocalVisualizer(Visualizer):
+class KIELocalVisualizer(BaseLocalVisualizer):
     """The MMOCR Text Detection Local Visualizer.

     Args:
@@ -65,102 +44,6 @@ class KIELocalVisualizer(Visualizer):
         super().__init__(name=name, **kwargs)
         self.is_openset = is_openset

-    @staticmethod
-    def _draw_labels(visualizer: Visualizer,
-                     image: np.ndarray,
-                     labels: Union[np.ndarray, torch.Tensor],
-                     bboxes: Union[np.ndarray, torch.Tensor],
-                     colors: Union[str, Sequence[str]] = 'k',
-                     font_size: Union[int, float] = 10,
-                     auto_font_size: bool = False) -> np.ndarray:
-        """Draw labels on image.
-
-        Args:
-            image (np.ndarray): The origin image to draw. The format
-                should be RGB.
-            labels (Union[np.ndarray, torch.Tensor]): The labels to draw.
-            bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw.
-            colors (Union[str, Sequence[str]]): The colors of labels.
-                ``colors`` can have the same length with labels or just single
-                value. If ``colors`` is single value, all the labels will have
-                the same colors. Refer to `matplotlib.colors` for full list of
-                formats that are accepted. Defaults to 'k'.
-            font_size (Union[int, float]): The font size of labels. Defaults
-                to 10.
-            auto_font_size (bool): Whether to automatically adjust font size.
-                Defaults to False.
-        """
-        if colors is not None and isinstance(colors, Sequence):
-            size = math.ceil(len(labels) / len(colors))
-            colors = (colors * size)[:len(labels)]
-        if auto_font_size:
-            assert font_size is not None and isinstance(
-                font_size, (int, float))
-            font_size = (bboxes[:, 2:] - bboxes[:, :2]).min(-1) * font_size
-            font_size = font_size.tolist()
-        visualizer.set_image(image)
-        visualizer.draw_texts(
-            labels, (bboxes[:, :2] + bboxes[:, 2:]) / 2,
-            vertical_alignments='center',
-            horizontal_alignments='center',
-            colors='k',
-            font_sizes=font_size)
-        return visualizer.get_image()
-
-    @staticmethod
-    def _draw_polygons(visualizer: Visualizer,
-                       image: np.ndarray,
-                       polygons: Sequence[np.ndarray],
-                       colors: Union[str, Sequence[str]] = 'g',
-                       filling: bool = False,
-                       line_width: Union[int, float] = 0.5,
-                       alpha: float = 0.5) -> np.ndarray:
-        if colors is not None and isinstance(colors, Sequence):
-            size = math.ceil(len(polygons) / len(colors))
-            colors = (colors * size)[:len(polygons)]
-        visualizer.set_image(image)
-        if filling:
-            visualizer.draw_polygons(
-                polygons,
-                face_colors=colors,
-                edge_colors=colors,
-                line_widths=line_width,
-                alpha=alpha)
-        else:
-            visualizer.draw_polygons(
-                polygons,
-                edge_colors=colors,
-                line_widths=line_width,
-                alpha=alpha)
-        return visualizer.get_image()
-
-    @staticmethod
-    def _draw_bboxes(visualizer: Visualizer,
-                     image: np.ndarray,
-                     bboxes: Union[np.ndarray, torch.Tensor],
-                     colors: Union[str, Sequence[str]] = 'g',
-                     filling: bool = False,
-                     line_width: Union[int, float] = 0.5,
-                     alpha: float = 0.5) -> np.ndarray:
-        if colors is not None and isinstance(colors, Sequence):
-            size = math.ceil(len(bboxes) / len(colors))
-            colors = (colors * size)[:len(bboxes)]
-        visualizer.set_image(image)
-        if filling:
-            visualizer.draw_bboxes(
-                bboxes,
-                face_colors=colors,
-                edge_colors=colors,
-                line_widths=line_width,
-                alpha=alpha)
-        else:
-            visualizer.draw_bboxes(
-                bboxes,
-                edge_colors=colors,
-                line_widths=line_width,
-                alpha=alpha)
-        return visualizer.get_image()
-
     def _draw_edge_label(self,
                          image: np.ndarray,
                          edge_labels: Union[np.ndarray, torch.Tensor],
@@ -182,6 +65,9 @@ class KIELocalVisualizer(Visualizer):
             arrow_colors (str, optional): The colors of arrows. Refer to
                 `matplotlib.colors` for full list of formats that are accepted.
                 Defaults to 'g'.
+
+        Returns:
+            np.ndarray: The image with edge labels drawn.
         """
         pairs = np.where(edge_labels > 0)
         key_bboxes = bboxes[pairs[0]]
@@ -253,49 +139,45 @@ class KIELocalVisualizer(Visualizer):
             class_names (dict): The class names for bbox labels.
             is_openset (bool): Whether the dataset is openset. Defaults to
                 False.
             arrow_colors (str, optional): The colors of arrows. Refer to
                 `matplotlib.colors` for full list of formats that are accepted.
                 Defaults to 'g'.

         Returns:
             np.ndarray: The image with instances drawn.
         """
         img_shape = image.shape[:2]
         empty_shape = (img_shape[0], img_shape[1], 3)

-        if polygons:
-            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
-        if polygons:
-            image = self._draw_polygons(
-                self, image, polygons, filling=True, colors=PALETTE)
-        else:
-            image = self._draw_bboxes(
-                self, image, bboxes, filling=True, colors=PALETTE)
-
         text_image = np.full(empty_shape, 255, dtype=np.uint8)
-        text_image = self._draw_labels(self, text_image, texts, bboxes)
-        if polygons:
-            text_image = self._draw_polygons(
-                self, text_image, polygons, colors=PALETTE)
-        else:
-            text_image = self._draw_bboxes(
-                self, text_image, bboxes, colors=PALETTE)
+        text_image = self.get_labels_image(text_image, texts, bboxes)

         classes_image = np.full(empty_shape, 255, dtype=np.uint8)
         bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels]
-        classes_image = self._draw_labels(self, classes_image, bbox_classes,
-                                          bboxes)
+        classes_image = self.get_labels_image(classes_image, bbox_classes,
+                                              bboxes)
         if polygons:
-            classes_image = self._draw_polygons(
-                self, classes_image, polygons, colors=PALETTE)
+            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
+            image = self.get_polygons_image(
+                image, polygons, filling=True, colors=self.PALETTE)
+            text_image = self.get_polygons_image(
+                text_image, polygons, colors=self.PALETTE)
+            classes_image = self.get_polygons_image(
+                classes_image, polygons, colors=self.PALETTE)
         else:
-            classes_image = self._draw_bboxes(
-                self, classes_image, bboxes, colors=PALETTE)
-
-        edge_image = None
+            image = self.get_bboxes_image(
+                image, bboxes, filling=True, colors=self.PALETTE)
+            text_image = self.get_bboxes_image(
+                text_image, bboxes, colors=self.PALETTE)
+            classes_image = self.get_bboxes_image(
+                classes_image, bboxes, colors=self.PALETTE)
+        cat_image = [image, text_image, classes_image]
         if is_openset:
             edge_image = np.full(empty_shape, 255, dtype=np.uint8)
             edge_image = self._draw_edge_label(edge_image, edge_labels, bboxes,
                                                texts, arrow_colors)
-        cat_image = []
-        for i in [image, text_image, classes_image, edge_image]:
-            if i is not None:
-                cat_image.append(i)
-        return np.concatenate(cat_image, axis=1)
+            cat_image.append(edge_image)
+        return self._cat_image(cat_image, axis=1)

     def add_datasample(self,
                        name: str,
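Note: the rewrite collects the panels in one list and lets _cat_image drop
the missing ones. A toy illustration of that stitching pattern (the shapes
are invented):

import numpy as np

h, w = 32, 48
image = np.zeros((h, w, 3), np.uint8)
text_image = np.full((h, w, 3), 255, np.uint8)
classes_image = np.full((h, w, 3), 255, np.uint8)

cat_image = [image, text_image, classes_image]
# In the openset case an edge_image panel is appended as well, so the
# final width is a multiple of w that depends on the panel count.
panel = np.concatenate(cat_image, axis=1)
assert panel.shape == (h, 3 * w, 3)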
@@ -336,8 +218,7 @@ class KIELocalVisualizer(Visualizer):
             out_file (str): Path to output file. Defaults to None.
             step (int): Global step value to record. Defaults to 0.
         """
-        gt_img_data = None
-        pred_img_data = None
+        cat_images = list()

         if draw_gt:
             gt_bboxes = data_sample.gt_instances.bboxes
@@ -350,6 +231,7 @@ class KIELocalVisualizer(Visualizer):
                                                 gt_texts,
                                                 self.dataset_meta['category'],
                                                 self.is_openset, 'g')
+            cat_images.append(gt_img_data)
         if draw_pred:
             gt_bboxes = data_sample.gt_instances.bboxes
             pred_labels = data_sample.pred_instances.labels
@@ -362,22 +244,19 @@ class KIELocalVisualizer(Visualizer):
                                                   gt_texts,
                                                   self.dataset_meta['category'],
                                                   self.is_openset, 'r')
-        if gt_img_data is not None and pred_img_data is not None:
-            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=0)
-        elif gt_img_data is not None:
-            drawn_img = gt_img_data
-        elif pred_img_data is not None:
-            drawn_img = pred_img_data
-        else:
-            drawn_img = image
+            cat_images.append(pred_img_data)

+        cat_images = self._cat_image(cat_images, axis=0)
+        if cat_images is None:
+            cat_images = image

         if show:
-            self.show(drawn_img, win_name=name, wait_time=wait_time)
+            self.show(cat_images, win_name=name, wait_time=wait_time)
         else:
-            self.add_image(name, drawn_img, step)
+            self.add_image(name, cat_images, step)

         if out_file is not None:
-            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+            mmcv.imwrite(cat_images[..., ::-1], out_file)

     def draw_arrows(self,
                     x_data: Union[np.ndarray, torch.Tensor],
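Note: every rewritten add_datasample now ends with the same fallback: if
neither GT nor prediction produced a panel, the raw input is used. A
self-contained sketch of that idiom (cat_image below is a local stand-in
mirroring _cat_image):

import numpy as np

def cat_image(imgs, axis):
    imgs = [img for img in imgs if img is not None]
    return np.concatenate(imgs, axis=axis) if imgs else None

image = np.zeros((8, 8, 3), np.uint8)
drawn = cat_image([None, None], axis=0)  # nothing was drawn
if drawn is None:
    drawn = image                        # fall back to the raw input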
@@ -1,16 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union

 import mmcv
 import numpy as np
-from mmengine.visualization import Visualizer
+import torch

 from mmocr.registry import VISUALIZERS
 from mmocr.structures import TextDetDataSample
+from .base_visualizer import BaseLocalVisualizer


 @VISUALIZERS.register_module()
-class TextDetLocalVisualizer(Visualizer):
+class TextDetLocalVisualizer(BaseLocalVisualizer):
     """The MMOCR Text Detection Local Visualizer.

     Args:
@@ -62,6 +63,42 @@ class TextDetLocalVisualizer(Visualizer):
         self.line_width = line_width
         self.alpha = alpha

+    def _draw_instances(
+        self,
+        image: np.ndarray,
+        bboxes: Union[np.ndarray, torch.Tensor],
+        polygons: Sequence[np.ndarray],
+        color: Union[str, Tuple, List[str], List[Tuple]] = 'g',
+    ) -> np.ndarray:
+        """Draw bboxes and polygons on image.
+
+        Args:
+            image (np.ndarray): The origin image to draw.
+            bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw.
+            polygons (Sequence[np.ndarray]): The polygons to draw.
+            color (Union[str, tuple, list[str], list[tuple]]): The
+                colors of polygons and bboxes. ``colors`` can have the same
+                length with lines or just single value. If ``colors`` is
+                single value, all the lines will have the same colors. Refer
+                to `matplotlib.colors` for full list of formats that are
+                accepted. Defaults to 'g'.
+
+        Returns:
+            np.ndarray: The image with bboxes and polygons drawn.
+        """
+        if polygons is not None and self.with_poly:
+            polygons = [polygon.reshape(-1, 2) for polygon in polygons]
+            image = self.get_polygons_image(
+                image, polygons, filling=True, colors=color, alpha=self.alpha)
+        if bboxes is not None and self.with_bbox:
+            image = self.get_bboxes_image(
+                image,
+                bboxes,
+                colors=color,
+                line_width=self.line_width,
+                alpha=self.alpha)
+        return image
+
     def add_datasample(self,
                        name: str,
                        image: np.ndarray,
@@ -101,79 +138,32 @@ class TextDetLocalVisualizer(Visualizer):
                 and masks. Defaults to 0.3.
             step (int): Global step value to record. Defaults to 0.
         """
-        gt_img_data = None
-        pred_img_data = None
-
-        if (draw_gt and data_sample is not None
-                and 'gt_instances' in data_sample):
-            gt_instances = data_sample.gt_instances
-
-            self.set_image(image)
-
-            if self.with_poly and 'polygons' in gt_instances:
-                gt_polygons = gt_instances.polygons
-                gt_polygons = [
-                    gt_polygon.reshape(-1, 2) for gt_polygon in gt_polygons
-                ]
-                self.draw_polygons(
-                    gt_polygons,
-                    alpha=self.alpha,
-                    edge_colors=self.gt_color,
-                    line_widths=self.line_width)
-
-            if self.with_bbox and 'bboxes' in gt_instances:
-                gt_bboxes = gt_instances.bboxes
-                self.draw_bboxes(
-                    gt_bboxes,
-                    alpha=self.alpha,
-                    edge_colors=self.gt_color,
-                    line_widths=self.line_width)
-
-            gt_img_data = self.get_image()
-
-        if draw_pred and data_sample is not None \
-                and 'pred_instances' in data_sample:
-            pred_instances = data_sample.pred_instances
-            pred_instances = pred_instances[
-                pred_instances.scores > pred_score_thr].cpu()
-
-            self.set_image(image)
-
-            if self.with_poly and 'polygons' in pred_instances:
-                pred_polygons = pred_instances.polygons
-                pred_polygons = [
-                    pred_polygon.reshape(-1, 2)
-                    for pred_polygon in pred_polygons
-                ]
-                self.draw_polygons(
-                    pred_polygons,
-                    alpha=self.alpha,
-                    edge_colors=self.pred_color,
-                    line_widths=self.line_width)
-
-            if self.with_bbox and 'bboxes' in pred_instances:
-                pred_bboxes = pred_instances.bboxes
-                self.draw_bboxes(
-                    pred_bboxes,
-                    alpha=self.alpha,
-                    edge_colors=self.pred_color,
-                    line_widths=self.line_width)
-
-            pred_img_data = self.get_image()
-
-        if gt_img_data is not None and pred_img_data is not None:
-            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
-        elif gt_img_data is not None:
-            drawn_img = gt_img_data
-        elif pred_img_data is not None:
-            drawn_img = pred_img_data
-        else:
-            drawn_img = image
-
+        cat_images = []
+        if data_sample is not None:
+            if draw_gt and 'gt_instances' in data_sample:
+                gt_instances = data_sample.gt_instances
+                gt_polygons = gt_instances.get('polygons', None)
+                gt_bboxes = gt_instances.get('bboxes', None)
+                gt_img_data = self._draw_instances(image.copy(), gt_bboxes,
+                                                   gt_polygons, self.gt_color)
+                cat_images.append(gt_img_data)
+            if draw_pred and 'pred_instances' in data_sample:
+                pred_instances = data_sample.pred_instances
+                pred_instances = pred_instances[
+                    pred_instances.scores > pred_score_thr].cpu()
+                pred_polygons = pred_instances.get('polygons', None)
+                pred_bboxes = pred_instances.get('bboxes', None)
+                pred_img_data = self._draw_instances(image.copy(), pred_bboxes,
+                                                     pred_polygons,
+                                                     self.pred_color)
+                cat_images.append(pred_img_data)
+        cat_images = self._cat_image(cat_images, axis=1)
+        if cat_images is None:
+            cat_images = image
         if show:
-            self.show(drawn_img, win_name=name, wait_time=wait_time)
+            self.show(cat_images, win_name=name, wait_time=wait_time)
         else:
-            self.add_image(name, drawn_img, step)
+            self.add_image(name, cat_images, step)

         if out_file is not None:
-            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+            mmcv.imwrite(cat_images[..., ::-1], out_file)
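Note: a hedged end-to-end sketch of the rewritten detection path; the toy
box and polygon are invented, and the exact data-sample construction may
vary across mmengine versions:

import numpy as np
from mmengine.structures import InstanceData
from mmocr.structures import TextDetDataSample
from mmocr.visualization import TextDetLocalVisualizer

image = np.full((64, 128, 3), 255, dtype=np.uint8)
gt = InstanceData(
    bboxes=np.array([[10, 10, 60, 40]]),
    polygons=[np.array([10, 10, 60, 10, 60, 40, 10, 40])])
sample = TextDetDataSample(gt_instances=gt)

vis = TextDetLocalVisualizer(name='demo')
vis.add_datasample(
    'det', image, sample, draw_pred=False, out_file='det_vis.jpg')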
@@ -4,14 +4,14 @@ from typing import Dict, Optional, Tuple, Union
 import cv2
 import mmcv
 import numpy as np
-from mmengine.visualization import Visualizer

 from mmocr.registry import VISUALIZERS
 from mmocr.structures import TextRecogDataSample
+from .base_visualizer import BaseLocalVisualizer


 @VISUALIZERS.register_module()
-class TextRecogLocalVisualizer(Visualizer):
+class TextRecogLocalVisualizer(BaseLocalVisualizer):
     """MMOCR Text Detection Local Visualizer.

     Args:
@@ -46,6 +46,30 @@ class TextRecogLocalVisualizer(Visualizer):
         self.gt_color = gt_color
         self.pred_color = pred_color

+    def _draw_instances(self, image: np.ndarray, text: str) -> np.ndarray:
+        """Draw text on image.
+
+        Args:
+            image (np.ndarray): The image to draw.
+            text (str): The text to draw.
+
+        Returns:
+            np.ndarray: The image with text drawn.
+        """
+        height, width = image.shape[:2]
+        empty_img = np.full_like(image, 255)
+        self.set_image(empty_img)
+        font_size = 0.5 * width / (len(text) + 1)
+        self.draw_texts(
+            text,
+            np.array([width / 2, height / 2]),
+            colors=self.gt_color,
+            font_sizes=font_size,
+            vertical_alignments='center',
+            horizontal_alignments='center')
+        text_image = self.get_image()
+        return text_image
+
     def add_datasample(self,
                        name: str,
                        image: np.ndarray,
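Note: the font-size heuristic gives each glyph roughly half of an even
share of the image width. A quick arithmetic check of
0.5 * width / (len(text) + 1):

width = 256
text = 'mmocr'                            # 5 characters
font_size = 0.5 * width / (len(text) + 1)
print(font_size)                          # 128 / 6 ~= 21.3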
@@ -85,59 +109,28 @@ class TextRecogLocalVisualizer(Visualizer):
             pred_score_thr (float): Threshold of prediction score. It's not
                 used in this function. Defaults to None.
         """
-        gt_img_data = None
-        pred_img_data = None
         height, width = image.shape[:2]
         resize_height = 64
         resize_width = int(1.0 * width / height * resize_height)
         image = cv2.resize(image, (resize_width, resize_height))

         if image.ndim == 2:
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

+        cat_images = [image]
         if draw_gt and data_sample is not None and 'gt_text' in data_sample:
             gt_text = data_sample.gt_text.item
-            empty_img = np.full_like(image, 255)
-            self.set_image(empty_img)
-            font_size = 0.5 * resize_width / (len(gt_text) + 1)
-            self.draw_texts(
-                gt_text,
-                np.array([resize_width / 2, resize_height / 2]),
-                colors=self.gt_color,
-                font_sizes=font_size,
-                vertical_alignments='center',
-                horizontal_alignments='center')
-            gt_text_image = self.get_image()
-            gt_img_data = np.concatenate((image, gt_text_image), axis=0)
-
+            cat_images.append(self._draw_instances(image, gt_text))
         if (draw_pred and data_sample is not None
                 and 'pred_text' in data_sample):
             pred_text = data_sample.pred_text.item
-            empty_img = np.full_like(image, 255)
-            self.set_image(empty_img)
-            font_size = 0.5 * resize_width / (len(pred_text) + 1)
-            self.draw_texts(
-                pred_text,
-                np.array([resize_width / 2, resize_height / 2]),
-                colors=self.pred_color,
-                font_sizes=font_size,
-                vertical_alignments='center',
-                horizontal_alignments='center')
-            pred_text_image = self.get_image()
-            pred_img_data = np.concatenate((image, pred_text_image), axis=0)
-
-        if gt_img_data is not None and pred_img_data is not None:
-            drawn_img = np.concatenate((gt_img_data, pred_text_image), axis=0)
-        elif gt_img_data is not None:
-            drawn_img = gt_img_data
-        elif pred_img_data is not None:
-            drawn_img = pred_img_data
-        else:
-            drawn_img = image
+            cat_images.append(self._draw_instances(image, pred_text))
+        cat_images = self._cat_image(cat_images, axis=0)

         if show:
-            self.show(drawn_img, win_name=name, wait_time=wait_time)
+            self.show(cat_images, win_name=name, wait_time=wait_time)
         else:
-            self.add_image(name, drawn_img, step)
+            self.add_image(name, cat_images, step)

         if out_file is not None:
-            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+            mmcv.imwrite(cat_images[..., ::-1], out_file)
@@ -37,27 +37,26 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
                 should be the same as the number of bboxes.
             class_names (dict): The class names for bbox labels.
             is_openset (bool): Whether the dataset is openset. Default: False.

         Returns:
             np.ndarray: The image with instances drawn.
         """
         img_shape = image.shape[:2]
         empty_shape = (img_shape[0], img_shape[1], 3)
-
+        text_image = np.full(empty_shape, 255, dtype=np.uint8)
+        text_image = self.get_labels_image(
+            text_image, labels=texts, bboxes=bboxes)
         if polygons:
             polygons = [polygon.reshape(-1, 2) for polygon in polygons]
-        if polygons:
-            image = self._draw_polygons(
-                self, image, polygons, filling=True, colors=self.PALETTE)
+            image = self.get_polygons_image(
+                image, polygons, filling=True, colors=self.PALETTE)
+            text_image = self.get_polygons_image(
+                text_image, polygons, colors=self.PALETTE)
         else:
-            image = self._draw_bboxes(
-                self, image, bboxes, filling=True, colors=self.PALETTE)
-
-        text_image = np.full(empty_shape, 255, dtype=np.uint8)
-        text_image = self._draw_labels(self, text_image, texts, bboxes)
-        if polygons:
-            text_image = self._draw_polygons(
-                self, text_image, polygons, colors=self.PALETTE)
-        else:
-            text_image = self._draw_bboxes(
-                self, text_image, bboxes, colors=self.PALETTE)
+            image = self.get_bboxes_image(
+                image, bboxes, filling=True, colors=self.PALETTE)
+            text_image = self.get_bboxes_image(
+                text_image, bboxes, colors=self.PALETTE)
         return np.concatenate([image, text_image], axis=1)

     def add_datasample(self,
@@ -68,43 +67,69 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
                        draw_pred: bool = True,
                        show: bool = False,
                        wait_time: int = 0,
-                       pred_score_thr: float = None,
+                       pred_score_thr: float = 0.5,
                        out_file: Optional[str] = None,
                        step: int = 0) -> None:
-        gt_img_data = None
-        pred_img_data = None
+        """Draw datasample and save to all backends.
+
+        - If GT and prediction are plotted at the same time, they are
+          displayed in a stitched image where the left image is the
+          ground truth and the right image is the prediction.
+        - If ``show`` is True, all storage backends are ignored, and
+          the images will be displayed in a local window.
+        - If ``out_file`` is specified, the drawn image will be
+          saved to ``out_file``. This is usually used when the display
+          is not available.
+
+        Args:
+            name (str): The image identifier.
+            image (np.ndarray): The image to draw.
+            data_sample (:obj:`TextSpottingDataSample`, optional):
+                TextDetDataSample which contains gt and prediction. Defaults
+                to None.
+            draw_gt (bool): Whether to draw GT TextDetDataSample.
+                Defaults to True.
+            draw_pred (bool): Whether to draw Predicted TextDetDataSample.
+                Defaults to True.
+            show (bool): Whether to display the drawn image. Default to False.
+            wait_time (float): The interval of show (s). Defaults to 0.
+            out_file (str): Path to output file. Defaults to None.
+            pred_score_thr (float): The threshold to visualize the bboxes
+                and masks. Defaults to 0.3.
+            step (int): Global step value to record. Defaults to 0.
+        """
+        cat_images = []

         if draw_gt:
-            gt_bboxes = data_sample.gt_instances.bboxes
+            gt_bboxes = data_sample.gt_instances.get('bboxes', None)
             gt_texts = data_sample.gt_instances.texts
-            gt_polygons = data_sample.gt_instances.polygons
+            gt_polygons = data_sample.gt_instances.get('polygons', None)
             gt_img_data = self._draw_instances(image, gt_bboxes, gt_polygons,
                                                gt_texts)
+            cat_images.append(gt_img_data)

         if draw_pred:
             pred_instances = data_sample.pred_instances
             pred_instances = pred_instances[
                 pred_instances.scores > pred_score_thr].cpu().numpy()
+            pred_bboxes = pred_instances.get('bboxes', None)
             pred_texts = pred_instances.texts
-            pred_polygons = pred_instances.polygons
+            pred_polygons = pred_instances.get('polygons', None)
+            if pred_bboxes is None:
+                pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
+                pred_bboxes = np.array(pred_bboxes)
             pred_img_data = self._draw_instances(image, pred_bboxes,
                                                  pred_polygons, pred_texts)
-        if gt_img_data is not None and pred_img_data is not None:
-            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=0)
-        elif gt_img_data is not None:
-            drawn_img = gt_img_data
-        elif pred_img_data is not None:
-            drawn_img = pred_img_data
-        else:
-            drawn_img = image
+            cat_images.append(pred_img_data)

+        cat_images = self._cat_image(cat_images, axis=0)
+        if cat_images is None:
+            cat_images = image

         if show:
-            self.show(drawn_img, win_name=name, wait_time=wait_time)
+            self.show(cat_images, win_name=name, wait_time=wait_time)
         else:
-            self.add_image(name, drawn_img, step)
+            self.add_image(name, cat_images, step)

         if out_file is not None:
-            mmcv.imwrite(drawn_img[..., ::-1], out_file)
+            mmcv.imwrite(cat_images[..., ::-1], out_file)
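Note: when a spotter predicts only polygons, the rewrite derives boxes via
poly2bbox. A hedged stand-in showing the intended behavior (mmocr's own
poly2bbox lives in mmocr.utils):

import numpy as np

def poly2bbox(polygon):
    """Stand-in: axis-aligned box from a polygon's extreme coordinates."""
    polygon = np.asarray(polygon).reshape(-1, 2)
    return np.concatenate([polygon.min(axis=0), polygon.max(axis=0)])

poly = np.array([10, 12, 50, 8, 52, 30, 9, 28])
print(poly2bbox(poly))  # [ 9  8 52 30]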
@@ -1,890 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import shutil
import urllib
import warnings

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import mmocr.utils as utils


# TODO remove after KieVisualizer and TextSpotterVisualizer
def overlay_mask_img(img, mask):
    """Draw mask boundaries on image for visualization.

    Args:
        img (ndarray): The input image.
        mask (ndarray): The instance mask.

    Returns:
        img (ndarray): The output image with instance boundaries on it.
    """
    assert isinstance(img, np.ndarray)
    assert isinstance(mask, np.ndarray)

    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    cv2.drawContours(img, contours, -1, (0, 255, 0), 1)

    return img

def show_feature(features, names, to_uint8, out_file=None):
    """Visualize a list of feature maps.

    Args:
        features (list(ndarray)): The feature map list.
        names (list(str)): The visualized title list.
        to_uint8 (list(1|0)): The list indicating whether to convert
            feature maps to uint8.
        out_file (str): The output file name. If set to None,
            the output image will be shown without saving.
    """
    assert utils.is_type_list(features, np.ndarray)
    assert utils.is_type_list(names, str)
    assert utils.is_type_list(to_uint8, int)
    assert utils.is_none_or_type(out_file, str)
    assert utils.equal_len(features, names, to_uint8)

    num = len(features)
    row = col = math.ceil(math.sqrt(num))

    for i, (f, n) in enumerate(zip(features, names)):
        plt.subplot(row, col, i + 1)
        plt.title(n)
        if to_uint8[i]:
            f = f.astype(np.uint8)
        plt.imshow(f)
    if out_file is None:
        plt.show()
    else:
        plt.savefig(out_file)

def show_img_boundary(img, boundary):
    """Show image and instance boundaries.

    Args:
        img (ndarray): The input image.
        boundary (list[float or int]): The input boundary.
    """
    assert isinstance(img, np.ndarray)
    assert utils.is_type_list(boundary, (int, float))

    cv2.polylines(
        img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
        True,
        color=(0, 255, 0),
        thickness=1)
    plt.imshow(img)
    plt.show()

def show_pred_gt(preds,
                 gts,
                 show=False,
                 win_name='',
                 wait_time=0,
                 out_file=None):
    """Show detection and ground truth for one image.

    Args:
        preds (list[list[float]]): The detection boundary list.
        gts (list[list[float]]): The ground truth boundary list.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): The value of waitKey param.
        out_file (str): The filename of the output.
    """
    assert utils.is_2dlist(preds)
    assert utils.is_2dlist(gts)
    assert isinstance(show, bool)
    assert isinstance(win_name, str)
    assert isinstance(wait_time, int)
    assert utils.is_none_or_type(out_file, str)

    p_xy = [p for boundary in preds for p in boundary]
    gt_xy = [g for gt in gts for g in gt]

    max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0)

    width = int(max_xy[0]) + 100
    height = int(max_xy[1]) + 100

    img = np.ones((height, width, 3), np.int8) * 255
    pred_color = mmcv.color_val('red')
    gt_color = mmcv.color_val('blue')
    thickness = 1

    for boundary in preds:
        cv2.polylines(
            img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
            True,
            color=pred_color,
            thickness=thickness)
    for gt in gts:
        cv2.polylines(
            img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)],
            True,
            color=gt_color,
            thickness=thickness)
    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img

def imshow_pred_boundary(img,
                         boundaries_with_scores,
                         labels,
                         score_thr=0,
                         boundary_color='blue',
                         text_color='blue',
                         thickness=1,
                         font_scale=0.5,
                         show=True,
                         win_name='',
                         wait_time=0,
                         out_file=None,
                         show_score=False):
    """Draw boundaries and class labels (with scores) on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        boundaries_with_scores (list[list[float]]): Boundaries with scores.
        labels (list[int]): Labels of boundaries.
        score_thr (float): Minimum score of boundaries to be shown.
        boundary_color (str or tuple or :obj:`Color`): Color of boundaries.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename of the output.
        show_score (bool): Whether to show text instance score.
    """
    assert isinstance(img, (str, np.ndarray))
    assert utils.is_2dlist(boundaries_with_scores)
    assert utils.is_type_list(labels, int)
    assert utils.equal_len(boundaries_with_scores, labels)
    if len(boundaries_with_scores) == 0:
        warnings.warn('0 text found in ' + out_file)
        return None

    utils.valid_boundary(boundaries_with_scores[0])
    img = mmcv.imread(img)

    scores = np.array([b[-1] for b in boundaries_with_scores])
    inds = scores > score_thr
    boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]]
    scores = [scores[i] for i in np.where(inds)[0]]
    labels = [labels[i] for i in np.where(inds)[0]]

    boundary_color = mmcv.color_val(boundary_color)
    text_color = mmcv.color_val(text_color)
    font_scale = 0.5

    for boundary, score in zip(boundaries, scores):
        boundary_int = np.array(boundary).astype(np.int32)

        cv2.polylines(
            img, [boundary_int.reshape(-1, 1, 2)],
            True,
            color=boundary_color,
            thickness=thickness)

        if show_score:
            label_text = f'{score:.02f}'
            cv2.putText(img, label_text,
                        (boundary_int[0], boundary_int[1] - 2),
                        cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img

def imshow_text_char_boundary(img,
                              text_quads,
                              boundaries,
                              char_quads,
                              chars,
                              show=False,
                              thickness=1,
                              font_scale=0.5,
                              win_name='',
                              wait_time=-1,
                              out_file=None):
    """Draw text boxes and char boxes on img.

    Args:
        img (str or ndarray): The img to be displayed.
        text_quads (list[list[int|float]]): The text boxes.
        boundaries (list[list[int|float]]): The boundary list.
        char_quads (list[list[list[int|float]]]): A 2d list of char boxes.
            char_quads[i] is for the ith text, and char_quads[i][j] is the
            jth char of the ith text.
        chars (list[list[char]]): The string for each text box.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename of the output.
    """
    assert isinstance(img, (np.ndarray, str))
    assert utils.is_2dlist(text_quads)
    assert utils.is_2dlist(boundaries)
    assert utils.is_3dlist(char_quads)
    assert utils.is_2dlist(chars)
    assert utils.equal_len(text_quads, char_quads, boundaries)

    img = mmcv.imread(img)
    char_color = [mmcv.color_val('blue'), mmcv.color_val('green')]
    text_color = mmcv.color_val('red')
    text_inx = 0
    for text_box, boundary, char_box, txt in zip(text_quads, boundaries,
                                                 char_quads, chars):
        text_box = np.array(text_box)
        boundary = np.array(boundary)

        text_box = text_box.reshape(-1, 2).astype(np.int32)
        cv2.polylines(
            img, [text_box.reshape(-1, 1, 2)],
            True,
            color=text_color,
            thickness=thickness)
        if boundary.shape[0] > 0:
            cv2.polylines(
                img, [boundary.reshape(-1, 1, 2)],
                True,
                color=text_color,
                thickness=thickness)

        for b in char_box:
            b = np.array(b)
            c = char_color[text_inx % 2]
            b = b.astype(np.int32)
            cv2.polylines(
                img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness)

        label_text = ''.join(txt)
        cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2),
                    cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
        text_inx = text_inx + 1

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img

def tile_image(images):
    """Combine multiple images into one vertically.

    Args:
        images (list[np.ndarray]): Images to be combined.
    """
    assert isinstance(images, list)
    assert len(images) > 0

    for i, _ in enumerate(images):
        if len(images[i].shape) == 2:
            images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR)

    widths = [img.shape[1] for img in images]
    heights = [img.shape[0] for img in images]
    h, w = sum(heights), max(widths)
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    offset_y = 0
    for image in images:
        img_h, img_w = image.shape[:2]
        vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image
        offset_y += img_h

    return vis_img

def imshow_text_label(img,
                      pred_label,
                      gt_label,
                      show=False,
                      win_name='',
                      wait_time=-1,
                      out_file=None):
    """Draw predicted texts and ground truth texts on images.

    Args:
        img (str or np.ndarray): Image filename or loaded image.
        pred_label (str): Predicted texts.
        gt_label (str): Ground truth texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str): The filename of the output.
    """
    assert isinstance(img, (np.ndarray, str))
    assert isinstance(pred_label, str)
    assert isinstance(gt_label, str)
    assert isinstance(show, bool)
    assert isinstance(win_name, str)
    assert isinstance(wait_time, int)

    img = mmcv.imread(img)

    src_h, src_w = img.shape[:2]
    resize_height = 64
    resize_width = int(1.0 * src_w / src_h * resize_height)
    img = cv2.resize(img, (resize_width, resize_height))
    h, w = img.shape[:2]

    if is_contain_chinese(pred_label):
        pred_img = draw_texts_by_pil(img, [pred_label], None)
    else:
        pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
        cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
                    0.9, (0, 0, 255), 2)
    images = [pred_img, img]

    if gt_label != '':
        if is_contain_chinese(gt_label):
            gt_img = draw_texts_by_pil(img, [gt_label], None)
        else:
            gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
            cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
                        0.9, (255, 0, 0), 2)
        images.append(gt_img)

    img = tile_image(images)

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img

def imshow_node(img,
                result,
                boxes,
                idx_to_cls={},
                show=False,
                win_name='',
                wait_time=-1,
                out_file=None):

    img = mmcv.imread(img)
    h, w = img.shape[:2]

    max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()

    texts, text_boxes = [], []
    for i, box in enumerate(boxes):
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=(255, 255, 0),
            thickness=1)
        x_min = int(min(point[0] for point in new_box))
        y_min = int(min(point[1] for point in new_box))

        # text
        pred_label = str(node_pred_label[i])
        if pred_label in idx_to_cls:
            pred_label = idx_to_cls[pred_label]
        pred_score = f'{node_pred_score[i]:.2f}'
        text = pred_label + '(' + pred_score + ')'
        texts.append(text)

        # text box
        font_size = int(
            min(
                abs(new_box[3][1] - new_box[0][1]),
                abs(new_box[1][0] - new_box[0][0])))
        char_num = len(text)
        text_box = [
            x_min * 2, y_min, x_min * 2 + font_size * char_num, y_min,
            x_min * 2 + font_size * char_num, y_min + font_size, x_min * 2,
            y_min + font_size
        ]
        text_boxes.append(text_box)

    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
    pred_img = draw_texts_by_pil(
        pred_img, texts, text_boxes, draw_box=False, on_ori_img=True)

    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
    vis_img[:, :w] = img
    vis_img[:, w:] = pred_img

    if show:
        mmcv.imshow(vis_img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(vis_img, out_file)

    return vis_img

def gen_color():
    """Generate BGR color schemes."""
    color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
                  (123, 151, 138), (187, 200, 178), (148, 137, 69),
                  (169, 200, 200), (155, 175, 131), (154, 194, 182),
                  (178, 190, 137), (140, 211, 222), (83, 156, 222)]
    return color_list

def draw_polygons(img, polys):
    """Draw polygons on image.

    Args:
        img (np.ndarray): The original image.
        polys (list[list[float]]): Detected polygons.

    Return:
        out_img (np.ndarray): Visualized image.
    """
    dst_img = img.copy()
    color_list = gen_color()
    out_img = dst_img
    for idx, poly in enumerate(polys):
        poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
        cv2.drawContours(
            img,
            np.array([poly]),
            -1,
            color_list[idx % len(color_list)],
            thickness=cv2.FILLED)
    out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
    return out_img

def get_optimal_font_scale(text, width):
    """Get optimal font scale for cv2.putText.

    Args:
        text (str): Text in one box.
        width (int): The box width.
    """
    for scale in reversed(range(0, 60, 1)):
        textSize = cv2.getTextSize(
            text,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=scale / 10,
            thickness=1)
        new_width = textSize[0][0]
        if new_width <= width:
            return scale / 10
    return 1
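Note: the helper scans font scales from large (5.9) down to small and keeps
the first one whose rendered width fits the box. A quick sketch of the same
probe it performs (the text and scale are invented):

import cv2

# cv2.getTextSize returns ((width, height), baseline) for a given scale;
# get_optimal_font_scale stops once width fits the target box.
(text_w, text_h), _ = cv2.getTextSize(
    'mmocr', cv2.FONT_HERSHEY_SIMPLEX, 1.0, 1)
print(text_w, text_h)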


def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False):
    """Draw boxes and texts on empty img.

    Args:
        img (np.ndarray): The original image.
        texts (list[str]): Recognized texts.
        boxes (list[list[float]]): Detected bounding boxes.
        draw_box (bool): Whether draw box or not. If False, draw text only.
        on_ori_img (bool): If True, draw box and text on input image,
            else, on a new empty image.

    Return:
        out_img (np.ndarray): Visualized image.
    """
    color_list = gen_color()
    h, w = img.shape[:2]
    if boxes is None:
        boxes = [[0, 0, w, 0, w, h, 0, h]]
    assert len(texts) == len(boxes)

    if on_ori_img:
        out_img = img
    else:
        out_img = np.ones((h, w, 3), dtype=np.uint8) * 255
    for idx, (box, text) in enumerate(zip(boxes, texts)):
        if draw_box:
            new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
            Pts = np.array([new_box], np.int32)
            cv2.polylines(
                out_img, [Pts.reshape((-1, 1, 2))],
                True,
                color=color_list[idx % len(color_list)],
                thickness=1)
        min_x = int(min(box[0::2]))
        max_y = int(
            np.mean(np.array(box[1::2])) + 0.2 *
            (max(box[1::2]) - min(box[1::2])))
        font_scale = get_optimal_font_scale(
            text, int(max(box[0::2]) - min(box[0::2])))
        cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, (0, 0, 0), 1)

    return out_img

def draw_texts_by_pil(img,
                      texts,
                      boxes=None,
                      draw_box=True,
                      on_ori_img=False,
                      font_size=None,
                      fill_color=None,
                      draw_pos=None,
                      return_text_size=False):
    """Draw boxes and texts on empty image, especially for Chinese.

    Args:
        img (np.ndarray): The original image.
        texts (list[str]): Recognized texts.
        boxes (list[list[float]]): Detected bounding boxes.
        draw_box (bool): Whether draw box or not. If False, draw text only.
        on_ori_img (bool): If True, draw box and text on input image,
            else on a new empty image.
        font_size (int, optional): Size to create a font object for a font.
        fill_color (tuple(int), optional): Fill color for text.
        draw_pos (list[tuple(int)], optional): Start point to draw each text.
        return_text_size (bool): If True, return the list of text size.

    Returns:
        (np.ndarray, list[tuple]) or np.ndarray: Return a tuple
        ``(out_img, text_sizes)``, where ``out_img`` is the output image
        with texts drawn on it and ``text_sizes`` are the size of drawing
        texts. If ``return_text_size`` is False, only the output image will be
        returned.
    """

    color_list = gen_color()
    h, w = img.shape[:2]
    if boxes is None:
        boxes = [[0, 0, w, 0, w, h, 0, h]]
    if draw_pos is None:
        draw_pos = [None for _ in texts]
    assert len(boxes) == len(texts) == len(draw_pos)

    if fill_color is None:
        fill_color = (0, 0, 0)

    if on_ori_img:
        out_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    else:
        out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
    out_draw = ImageDraw.Draw(out_img)

    text_sizes = []
    for idx, (box, text, ori_point) in enumerate(zip(boxes, texts, draw_pos)):
        if len(text) == 0:
            continue
        min_x, max_x = min(box[0::2]), max(box[0::2])
        min_y, max_y = min(box[1::2]), max(box[1::2])
        color = tuple(list(color_list[idx % len(color_list)])[::-1])
        if draw_box:
            out_draw.line(box, fill=color, width=1)
        dirname, _ = os.path.split(os.path.abspath(__file__))
        font_path = os.path.join(dirname, 'font.TTF')
        if not os.path.exists(font_path):
            url = ('https://download.openmmlab.com/mmocr/data/font.TTF')
            print(f'Downloading {url} ...')
            local_filename, _ = urllib.request.urlretrieve(url)
            shutil.move(local_filename, font_path)
        tmp_font_size = font_size
        if tmp_font_size is None:
            box_width = max(max_x - min_x, max_y - min_y)
            tmp_font_size = int(0.9 * box_width / len(text))
        fnt = ImageFont.truetype(font_path, tmp_font_size)
        if ori_point is None:
            ori_point = (min_x + 1, min_y + 1)
        out_draw.text(ori_point, text, font=fnt, fill=fill_color)
        text_sizes.append(fnt.getsize(text))

    del out_draw

    out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)

    if return_text_size:
        return out_img, text_sizes

    return out_img

def is_contain_chinese(check_str):
    """Check whether string contains Chinese or not.

    Args:
        check_str (str): String to be checked.

    Return True if contains Chinese, else False.
    """
    for ch in check_str:
        if '\u4e00' <= ch <= '\u9fff':
            return True
    return False
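Note: the check walks the string and tests each character against the CJK
Unified Ideographs block (U+4E00 to U+9FFF). For instance:

print(is_contain_chinese('mmocr'))      # False
print(is_contain_chinese('mmocr文字'))  # True, '文' falls in U+4E00-U+9FFF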


def det_recog_show_result(img, end2end_res, out_file=None):
    """Draw ``result`` (boxes and texts) on ``img``.

    Args:
        img (str or np.ndarray): The image to be displayed.
        end2end_res (dict): Text detect and recognize results.
        out_file (str): Image path where the visualized image should be saved.

    Return:
        out_img (np.ndarray): Visualized image.
    """
    img = mmcv.imread(img)
    boxes, texts = [], []
    for res in end2end_res['result']:
        boxes.append(res['box'])
        texts.append(res['text'])
    box_vis_img = draw_polygons(img, boxes)

    if is_contain_chinese(''.join(texts)):
        text_vis_img = draw_texts_by_pil(img, texts, boxes)
    else:
        text_vis_img = draw_texts(img, texts, boxes)

    h, w = img.shape[:2]
    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
    out_img[:, :w, :] = box_vis_img
    out_img[:, w:, :] = text_vis_img

    if out_file:
        mmcv.imwrite(out_img, out_file)

    return out_img



def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
    """Draw text nodes and their predicted relationships on an empty canvas.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:

            - img_metas (list[dict]): List of meta information dictionaries.
            - nodes (Tensor): Node prediction with size:
              number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        edge_thresh (float): Score threshold for edge classification.
        keynode_thresh (float): Score threshold for node (``key``)
            classification.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """

    h, w = img.shape[:2]

    vis_area_width = w // 3 * 2
    vis_area_height = h
    dist_key_to_value = vis_area_width // 2
    dist_pair_to_pair = 30

    bbox_x1 = dist_pair_to_pair
    bbox_y1 = 0

    new_w = vis_area_width
    new_h = vis_area_height
    pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255

    nodes = result['nodes'].detach().cpu()
    texts = result['img_metas'][0]['ori_texts']
    num_nodes = result['nodes'].size(0)
    edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes)

    # (i, j) is a valid pair if either edge_score(node_i->node_j) or
    # edge_score(node_j->node_i) exceeds edge_thresh
    pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True)
    pairs = (pairs[0].numpy().tolist(), pairs[1].numpy().tolist())

    # 1. "for n1, n2 in zip(*pairs) if n1 < n2":
    #    Only (n1, n2) with n1 < n2 is included, not (n2, n1), to avoid
    #    duplication.
    # 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]":
    #    nodes[n1, 1] is the score that this node is predicted as key,
    #    nodes[n1, 2] is the score that this node is predicted as value.
    #    If nodes[n1, 1] > nodes[n1, 2], n1 is taken as the key index,
    #    and n2 as the value index.
    result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1)
                    for n1, n2 in zip(*pairs) if n1 < n2]

    result_pairs.sort()
    result_pairs_score = [
        torch.max(edges[n1, n2], edges[n2, n1]) for n1, n2 in result_pairs
    ]
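    # Worked example (added note, values are made up): with edge_thresh=0.5,
    # two nodes and edges = [[0.0, 0.9], [0.1, 0.0]], max(edges, edges.T)
    # marks (0, 1) as a pair. If nodes[0, 1] > nodes[0, 2] (node 0 scores
    # higher as key than as value), the pair is kept as (key=0, value=1)
    # with score max(0.9, 0.1) = 0.9.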

    key_current_idx = -1
    pos_current = (-1, -1)
    newline_flag = False

    key_font_size = 15
    value_font_size = 15
    key_font_color = (0, 0, 0)
    value_font_color = (0, 0, 255)
    arrow_color = (0, 0, 255)
    score_color = (0, 255, 0)
    for pair, pair_score in zip(result_pairs, result_pairs_score):
        key_idx = pair[0]
        if nodes[key_idx, 1] < keynode_thresh:
            continue
        if key_idx != key_current_idx:
            # move y-coords down for a new key
            bbox_y1 += 10
            # enlarge blank area to show key-value info
            if newline_flag:
                bbox_x1 += vis_area_width
                tmp_img = np.ones(
                    (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_w += vis_area_width
                newline_flag = False
                bbox_y1 = 10
        key_text = texts[key_idx]
        key_pos = (bbox_x1, bbox_y1)
        value_idx = pair[1]
        value_text = texts[value_idx]
        value_pos = (bbox_x1 + dist_key_to_value, bbox_y1)
        if key_idx != key_current_idx:
            # draw text for a new key
            key_current_idx = key_idx
            pred_edge_img, text_sizes = draw_texts_by_pil(
                pred_edge_img, [key_text],
                draw_box=False,
                on_ori_img=True,
                font_size=key_font_size,
                fill_color=key_font_color,
                draw_pos=[key_pos],
                return_text_size=True)
            pos_right_bottom = (key_pos[0] + text_sizes[0][0],
                                key_pos[1] + text_sizes[0][1])
            pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10)
            pred_edge_img = cv2.arrowedLine(
                pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10),
                (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color,
                1)
            score_pos_x = int(
                (pos_right_bottom[0] + bbox_x1 + dist_key_to_value) / 2.)
            score_pos_y = bbox_y1 + 10 - int(key_font_size * 0.3)
        else:
            # draw arrow from key to value
            if newline_flag:
                tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3),
                                  dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_h += dist_pair_to_pair
            pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current,
                                            (bbox_x1 + dist_key_to_value - 5,
                                             bbox_y1 + 10), arrow_color, 1)
            score_pos_x = int(
                (pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.)
            score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.)
        # draw edge score
        cv2.putText(pred_edge_img, f'{pair_score:.2f}',
                    (score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4,
                    score_color)
        # draw text for value
        pred_edge_img = draw_texts_by_pil(
            pred_edge_img, [value_text],
            draw_box=False,
            on_ori_img=True,
            font_size=value_font_size,
            fill_color=value_font_color,
            draw_pos=[value_pos],
            return_text_size=False)
        bbox_y1 += dist_pair_to_pair
        if bbox_y1 + dist_pair_to_pair >= new_h:
            newline_flag = True

    return pred_edge_img
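
# Illustrative usage (an added sketch; `kie_output` stands in for a real
# forward_test result with 'nodes', 'edges' and 'img_metas'[0]['ori_texts']
# populated as documented above):
#
#   edge_img = draw_edge_result(img, kie_output, edge_thresh=0.5)
#   cv2.imwrite('edges.jpg', edge_img)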


def imshow_edge(img,
                result,
                boxes,
                show=False,
                win_name='',
                wait_time=-1,
                out_file=None):
    """Display the prediction results of the nodes and edges of the KIE
    model.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:

            - img_metas (list[dict]): List of meta information dictionaries.
            - nodes (Tensor): Node prediction with size:
              number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        boxes (list): The text boxes corresponding to the nodes.
        show (bool): Whether to show the image. Default: False.
        win_name (str): The window name. Default: ''.
        wait_time (float): Value of waitKey param. Default: -1.
        out_file (str or None): The filename to write the image.
            Default: None.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """
    img = mmcv.imread(img)
    h, w = img.shape[:2]
    color_list = gen_color()

    for i, box in enumerate(boxes):
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=color_list[i % len(color_list)],
            thickness=1)

    pred_img_h = h
    pred_img_w = w

    pred_edge_img = draw_edge_result(img, result)
    pred_img_h = max(pred_img_h, pred_edge_img.shape[0])
    pred_img_w += pred_edge_img.shape[1]

    vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8)
    vis_img[:h, :w] = img
    vis_img[:, w:] = 255

    height_t, width_t = pred_edge_img.shape[:2]
    vis_img[:height_t, w:(w + width_t)] = pred_edge_img

    if show:
        mmcv.imshow(vis_img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(vis_img, out_file)
        res_dic = {
            'boxes': boxes,
            'nodes': result['nodes'].detach().cpu(),
            'edges': result['edges'].detach().cpu(),
            'metas': result['img_metas'][0]
        }
        mmengine.dump(res_dic, f'{out_file}_res.pkl')

    return vis_img
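
# Illustrative usage (an added sketch; `kie_output` and `text_boxes` are
# placeholders for a real forward_test result and its detected boxes):
#
#   vis = imshow_edge('demo.jpg', kie_output, text_boxes,
#                     show=False, out_file='kie_vis.jpg')
#   # besides the image, 'kie_vis.jpg_res.pkl' is dumped with the raw tensors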

@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np

from mmocr.visualization import BaseLocalVisualizer


class TestBaseLocalVisualizer(TestCase):

    def test_get_labels_image(self):
        labels = ['a', 'b', 'c']
        image = np.zeros((40, 40, 3), dtype=np.uint8)
        bboxes = np.array([[0, 0, 10, 10], [10, 10, 20, 20],
                           [20, 20, 30, 30]])
        labels_image = BaseLocalVisualizer().get_labels_image(
            image,
            labels,
            bboxes=bboxes,
            auto_font_size=True,
            colors=['r', 'r', 'r', 'r'])
        self.assertEqual(labels_image.shape, (40, 40, 3))

    def test_get_polygons_image(self):
        polygons = [np.array([0, 0, 10, 10, 20, 20, 30, 30]).reshape(-1, 2)]
        image = np.zeros((40, 40, 3), dtype=np.uint8)
        polygons_image = BaseLocalVisualizer().get_polygons_image(
            image, polygons, colors=['r', 'r', 'r', 'r'])
        self.assertEqual(polygons_image.shape, (40, 40, 3))

        polygons_image = BaseLocalVisualizer().get_polygons_image(
            image, polygons, colors=['r', 'r', 'r', 'r'], filling=True)
        self.assertEqual(polygons_image.shape, (40, 40, 3))

    def test_get_bboxes_image(self):
        bboxes = np.array([[0, 0, 10, 10], [10, 10, 20, 20],
                           [20, 20, 30, 30]])
        image = np.zeros((40, 40, 3), dtype=np.uint8)
        bboxes_image = BaseLocalVisualizer().get_bboxes_image(
            image, bboxes, colors=['r', 'r', 'r', 'r'])
        self.assertEqual(bboxes_image.shape, (40, 40, 3))

        bboxes_image = BaseLocalVisualizer().get_bboxes_image(
            image, bboxes, colors=['r', 'r', 'r', 'r'], filling=True)
        self.assertEqual(bboxes_image.shape, (40, 40, 3))

    def test_cat_images(self):
        image1 = np.zeros((40, 40, 3), dtype=np.uint8)
        image2 = np.zeros((40, 40, 3), dtype=np.uint8)
        image = BaseLocalVisualizer()._cat_image([image1, image2], axis=1)
        self.assertEqual(image.shape, (40, 80, 3))

        image = BaseLocalVisualizer()._cat_image([], axis=0)
        self.assertIsNone(image)

        image = BaseLocalVisualizer()._cat_image([image1, None], axis=0)
        self.assertEqual(image.shape, (40, 40, 3))

@@ -105,6 +105,21 @@ class TestTextKIELocalVisualizer(unittest.TestCase):
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 4, c))

            visualizer = KIELocalVisualizer(is_openset=False)
            visualizer.dataset_meta = dict(category=[
                dict(id=0, name='bg'),
                dict(id=1, name='key'),
                dict(id=2, name='value'),
                dict(id=3, name='other')
            ])
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_pred=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 3, c))

    def _assert_image_and_shape(self, out_file, out_shape):
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)

@@ -101,6 +101,10 @@ class TestTextDetLocalVisualizer(unittest.TestCase):
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w, c))

            det_local_visualizer.add_datasample(
                'image', image, None, out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w, c))

    def _assert_image_and_shape(self, out_file, out_shape):
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)

@@ -46,7 +46,7 @@ class TestTextDetLocalVisualizer(unittest.TestCase):
                draw_pred=False)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # draw_gt = True + gt_sample + pred_sample
            # draw_gt = True
            recog_local_visualizer.add_datasample(
                'image',
                image,

@@ -56,7 +56,13 @@ class TestTextDetLocalVisualizer(unittest.TestCase):
                draw_pred=True)
            self._assert_image_and_shape(out_file, (h * 3, w, 3))

            # draw_gt = False + gt_sample + pred_sample
            # draw_gt = False
            recog_local_visualizer.add_datasample(
                'image', image, data_sample, draw_gt=False, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

            # gray image
            image = np.random.randint(0, 256, size=(h, w)).astype('uint8')
            recog_local_visualizer.add_datasample(
                'image', image, data_sample, draw_gt=False, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w, 3))

@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import unittest

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData

from mmocr.structures import TextDetDataSample
from mmocr.utils import bbox2poly
from mmocr.visualization import TextSpottingLocalVisualizer


class TestTextSpottingLocalVisualizer(unittest.TestCase):

    def setUp(self):
        h, w = 12, 10
        self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
        # gt_instances
        data_sample = TextDetDataSample()
        gt_instances_data = dict(
            bboxes=self._rand_bboxes(5, h, w),
            polygons=self._rand_polys(5, h, w),
            labels=torch.zeros(5, ),
            texts=['text1', 'text2', 'text3', 'text4', 'text5'])
        gt_instances = InstanceData(**gt_instances_data)
        data_sample.gt_instances = gt_instances

        pred_instances_data = dict(
            bboxes=self._rand_bboxes(5, h, w),
            labels=torch.zeros(5, ),
            scores=torch.rand((5, )),
            texts=['text1', 'text2', 'text3', 'text4', 'text5'])
        pred_instances = InstanceData(**pred_instances_data)
        data_sample.pred_instances = pred_instances
        data_sample = data_sample.numpy()
        self.data_sample = data_sample

    @staticmethod
    def _rand_bboxes(num_boxes, h, w):
        cx, cy, bw, bh = torch.rand(num_boxes, 4).T

        tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0)
        tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0)
        br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0)
        br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0)

        bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T

        return bboxes

    def _rand_polys(self, num_bboxes, h, w):
        bboxes = self._rand_bboxes(num_bboxes, h, w)
        bboxes = bboxes.tolist()
        polys = [bbox2poly(bbox) for bbox in bboxes]
        return polys

    def test_add_datasample(self):
        image = self.image
        h, w, c = image.shape

        visualizer = TextSpottingLocalVisualizer()
        visualizer.add_datasample('image', image, self.data_sample)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # test out
            out_file = osp.join(tmp_dir, 'out_file.jpg')
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                out_file=out_file,
                draw_gt=False,
                draw_pred=False)
            self._assert_image_and_shape(out_file, (h, w, c))

            visualizer.add_datasample(
                'image', image, self.data_sample, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w * 2, c))

            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 2, c))

            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_pred=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 2, c))

            bboxes = self.data_sample.pred_instances.pop('bboxes')
            bboxes = bboxes.tolist()
            polys = [bbox2poly(bbox) for bbox in bboxes]
            self.data_sample.pred_instances.polygons = polys
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 2, c))

    def _assert_image_and_shape(self, out_file, out_shape):
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)
        self.assertEqual(drawn_img.shape, out_shape)