mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Feature] add a new argument font_properties to set a specific font file in order to draw Chinese characters properly (#1709)
* [Feature] add new argument font_properties to set specific font file in order to draw Chinese characters properly * update the minimum mmengine version * add docstr
This commit is contained in:
parent
0894178343
commit
62d440fe8e
@ -193,6 +193,6 @@ MMOCR has different version requirements on MMEngine, MMCV and MMDetection at ea
|
|||||||
|
|
||||||
| MMOCR | MMEngine | MMCV | MMDetection |
|
| MMOCR | MMEngine | MMCV | MMDetection |
|
||||||
| -------------- | --------------------------- | -------------------------- | --------------------------- |
|
| -------------- | --------------------------- | -------------------------- | --------------------------- |
|
||||||
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
|
| dev-1.x | 0.6.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
|
||||||
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||||
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||||
|
@ -194,6 +194,6 @@ docker run --gpus all --shm-size=8g -it -v {实际数据目录}:/mmocr/data mmoc
|
|||||||
|
|
||||||
| MMOCR | MMEngine | MMCV | MMDetection |
|
| MMOCR | MMEngine | MMCV | MMDetection |
|
||||||
| -------------- | --------------------------- | -------------------------- | --------------------------- |
|
| -------------- | --------------------------- | -------------------------- | --------------------------- |
|
||||||
| dev-1.x | 0.5.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
|
| dev-1.x | 0.6.0 \<= mmengine \< 1.0.0 | 2.0.0rc4 \<= mmcv \< 2.1.0 | 3.0.0rc5 \<= mmdet \< 3.1.0 |
|
||||||
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
| 1.0.0rc\[4-5\] | 0.1.0 \<= mmengine \< 1.0.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||||
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
| 1.0.0rc\[0-3\] | 0.0.0 \<= mmengine \< 0.2.0 | 2.0.0rc1 \<= mmcv \< 2.1.0 | 3.0.0rc0 \<= mmdet \< 3.1.0 |
|
||||||
|
@ -16,7 +16,7 @@ mmcv_minimum_version = '2.0.0rc4'
|
|||||||
mmcv_maximum_version = '2.1.0'
|
mmcv_maximum_version = '2.1.0'
|
||||||
mmcv_version = digit_version(mmcv.__version__)
|
mmcv_version = digit_version(mmcv.__version__)
|
||||||
if mmengine is not None:
|
if mmengine is not None:
|
||||||
mmengine_minimum_version = '0.5.0'
|
mmengine_minimum_version = '0.6.0'
|
||||||
mmengine_maximum_version = '1.0.0'
|
mmengine_maximum_version = '1.0.0'
|
||||||
mmengine_version = digit_version(mmengine.__version__)
|
mmengine_version = digit_version(mmengine.__version__)
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import math
|
import math
|
||||||
from typing import List, Sequence, Union
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from matplotlib.font_manager import FontProperties
|
||||||
from mmengine.visualization import Visualizer
|
from mmengine.visualization import Visualizer
|
||||||
|
|
||||||
from mmocr.registry import VISUALIZERS
|
from mmocr.registry import VISUALIZERS
|
||||||
@ -27,6 +28,18 @@ class BaseLocalVisualizer(Visualizer):
|
|||||||
Defaults to empty dict.
|
Defaults to empty dict.
|
||||||
is_openset (bool, optional): Whether the visualizer is used in
|
is_openset (bool, optional): Whether the visualizer is used in
|
||||||
OpenSet. Defaults to False.
|
OpenSet. Defaults to False.
|
||||||
|
font_families (Union[str, List[str]]): The font families of labels.
|
||||||
|
Defaults to 'sans-serif'.
|
||||||
|
font_properties (Union[str, FontProperties], optional):
|
||||||
|
The font properties of texts. The format should be a path str
|
||||||
|
to font file or a `font_manager.FontProperties()` object.
|
||||||
|
If you want to draw Chinese texts, you need to prepare
|
||||||
|
a font file that can show Chinese characters properly.
|
||||||
|
For example: `simhei.ttf`,`simsun.ttc`,`simkai.ttf` and so on.
|
||||||
|
Then set font_properties=matplotlib.font_manager.FontProperties
|
||||||
|
(fname='path/to/font_file') or font_properties='path/to/font_file'
|
||||||
|
This function need mmengine version >=0.6.0.
|
||||||
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
|
PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
|
||||||
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
|
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
|
||||||
@ -53,19 +66,36 @@ class BaseLocalVisualizer(Visualizer):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
name: str = 'visualizer',
|
name: str = 'visualizer',
|
||||||
font_families: Union[str, List[str]] = 'sans-serif',
|
font_families: Union[str, List[str]] = 'sans-serif',
|
||||||
|
font_properties: Optional[Union[str, FontProperties]] = None,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
super().__init__(name=name, **kwargs)
|
super().__init__(name=name, **kwargs)
|
||||||
self.font_families = font_families
|
self.font_families = font_families
|
||||||
|
self.font_properties = self._set_font_properties(font_properties)
|
||||||
|
|
||||||
def get_labels_image(self,
|
def _set_font_properties(self,
|
||||||
image: np.ndarray,
|
fp: Optional[Union[str, FontProperties]] = None):
|
||||||
labels: Union[np.ndarray, torch.Tensor],
|
if fp is None:
|
||||||
bboxes: Union[np.ndarray, torch.Tensor],
|
return None
|
||||||
colors: Union[str, Sequence[str]] = 'k',
|
elif isinstance(fp, str):
|
||||||
font_size: Union[int, float] = 10,
|
return FontProperties(fname=fp)
|
||||||
auto_font_size: bool = False,
|
elif isinstance(fp, FontProperties):
|
||||||
font_families: Union[str, List[str]] = 'sans-serif'
|
return fp
|
||||||
) -> np.ndarray:
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'font_properties argument type should be'
|
||||||
|
' `str` or `matplotlib.font_manager.FontProperties`')
|
||||||
|
|
||||||
|
def get_labels_image(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
labels: Union[np.ndarray, torch.Tensor],
|
||||||
|
bboxes: Union[np.ndarray, torch.Tensor],
|
||||||
|
colors: Union[str, Sequence[str]] = 'k',
|
||||||
|
font_size: Union[int, float] = 10,
|
||||||
|
auto_font_size: bool = False,
|
||||||
|
font_families: Union[str, List[str]] = 'sans-serif',
|
||||||
|
font_properties: Optional[Union[str, FontProperties]] = None
|
||||||
|
) -> np.ndarray:
|
||||||
"""Draw labels on image.
|
"""Draw labels on image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -84,6 +114,17 @@ class BaseLocalVisualizer(Visualizer):
|
|||||||
Defaults to False.
|
Defaults to False.
|
||||||
font_families (Union[str, List[str]]): The font families of labels.
|
font_families (Union[str, List[str]]): The font families of labels.
|
||||||
Defaults to 'sans-serif'.
|
Defaults to 'sans-serif'.
|
||||||
|
font_properties (Union[str, FontProperties], optional):
|
||||||
|
The font properties of texts. The format should be a path str
|
||||||
|
to font file or a `font_manager.FontProperties()` object.
|
||||||
|
If you want to draw Chinese texts, you need to prepare
|
||||||
|
a font file that can show Chinese characters properly.
|
||||||
|
For example: `simhei.ttf`,`simsun.ttc`,`simkai.ttf` and so on.
|
||||||
|
Then set font_properties=matplotlib.font_manager.FontProperties
|
||||||
|
(fname='path/to/font_file') or
|
||||||
|
font_properties='path/to/font_file'.
|
||||||
|
This function need mmengine version >=0.6.0.
|
||||||
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
if not labels and not bboxes:
|
if not labels and not bboxes:
|
||||||
return image
|
return image
|
||||||
@ -102,7 +143,8 @@ class BaseLocalVisualizer(Visualizer):
|
|||||||
horizontal_alignments='center',
|
horizontal_alignments='center',
|
||||||
colors='k',
|
colors='k',
|
||||||
font_sizes=font_size,
|
font_sizes=font_size,
|
||||||
font_families=font_families)
|
font_families=font_families,
|
||||||
|
font_properties=font_properties)
|
||||||
return self.get_image()
|
return self.get_image()
|
||||||
|
|
||||||
def get_polygons_image(self,
|
def get_polygons_image(self,
|
||||||
|
@ -90,14 +90,16 @@ class KIELocalVisualizer(BaseLocalVisualizer):
|
|||||||
colors='k',
|
colors='k',
|
||||||
horizontal_alignments='center',
|
horizontal_alignments='center',
|
||||||
vertical_alignments='center',
|
vertical_alignments='center',
|
||||||
font_families=self.font_families)
|
font_families=self.font_families,
|
||||||
|
font_properties=self.font_properties)
|
||||||
if val_texts:
|
if val_texts:
|
||||||
self.draw_texts(
|
self.draw_texts(
|
||||||
val_texts, (bboxes[val_index, :2] + bboxes[val_index, 2:]) / 2,
|
val_texts, (bboxes[val_index, :2] + bboxes[val_index, 2:]) / 2,
|
||||||
colors='k',
|
colors='k',
|
||||||
horizontal_alignments='center',
|
horizontal_alignments='center',
|
||||||
vertical_alignments='center',
|
vertical_alignments='center',
|
||||||
font_families=self.font_families)
|
font_families=self.font_families,
|
||||||
|
font_properties=self.font_properties)
|
||||||
self.draw_arrows(
|
self.draw_arrows(
|
||||||
x_data,
|
x_data,
|
||||||
y_data,
|
y_data,
|
||||||
@ -153,7 +155,11 @@ class KIELocalVisualizer(BaseLocalVisualizer):
|
|||||||
|
|
||||||
text_image = np.full(empty_shape, 255, dtype=np.uint8)
|
text_image = np.full(empty_shape, 255, dtype=np.uint8)
|
||||||
text_image = self.get_labels_image(
|
text_image = self.get_labels_image(
|
||||||
text_image, texts, bboxes, font_families=self.font_families)
|
text_image,
|
||||||
|
texts,
|
||||||
|
bboxes,
|
||||||
|
font_families=self.font_families,
|
||||||
|
font_properties=self.font_properties)
|
||||||
|
|
||||||
classes_image = np.full(empty_shape, 255, dtype=np.uint8)
|
classes_image = np.full(empty_shape, 255, dtype=np.uint8)
|
||||||
bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels]
|
bbox_classes = [class_names[int(i)]['name'] for i in bbox_labels]
|
||||||
@ -161,7 +167,8 @@ class KIELocalVisualizer(BaseLocalVisualizer):
|
|||||||
classes_image,
|
classes_image,
|
||||||
bbox_classes,
|
bbox_classes,
|
||||||
bboxes,
|
bboxes,
|
||||||
font_families=self.font_families)
|
font_families=self.font_families,
|
||||||
|
font_properties=self.font_properties)
|
||||||
if polygons:
|
if polygons:
|
||||||
polygons = [polygon.reshape(-1, 2) for polygon in polygons]
|
polygons = [polygon.reshape(-1, 2) for polygon in polygons]
|
||||||
image = self.get_polygons_image(
|
image = self.get_polygons_image(
|
||||||
|
@ -68,7 +68,8 @@ class TextRecogLocalVisualizer(BaseLocalVisualizer):
|
|||||||
font_sizes=font_size,
|
font_sizes=font_size,
|
||||||
vertical_alignments='center',
|
vertical_alignments='center',
|
||||||
horizontal_alignments='center',
|
horizontal_alignments='center',
|
||||||
font_families=self.font_families)
|
font_families=self.font_families,
|
||||||
|
font_properties=self.font_properties)
|
||||||
text_image = self.get_image()
|
text_image = self.get_image()
|
||||||
return text_image
|
return text_image
|
||||||
|
|
||||||
|
@ -49,7 +49,8 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
|
|||||||
text_image,
|
text_image,
|
||||||
labels=texts,
|
labels=texts,
|
||||||
bboxes=bboxes,
|
bboxes=bboxes,
|
||||||
font_families=self.font_families)
|
font_families=self.font_families,
|
||||||
|
font_properties=self.font_properties)
|
||||||
if polygons:
|
if polygons:
|
||||||
polygons = [polygon.reshape(-1, 2) for polygon in polygons]
|
polygons = [polygon.reshape(-1, 2) for polygon in polygons]
|
||||||
image = self.get_polygons_image(
|
image = self.get_polygons_image(
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
mmcv>==2.0.0rc4,<2.1.0
|
mmcv>==2.0.0rc4,<2.1.0
|
||||||
mmdet>=3.0.0rc5,<3.1.0
|
mmdet>=3.0.0rc5,<3.1.0
|
||||||
mmengine>= 0.5.0, <1.0.0
|
mmengine>= 0.6.0, <1.0.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user