mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Support openset kie (#498)
* add openset kie dataset * update readme * add anno convert script * update docstring * update script * add & update docstring * fix typo * update docstring format pull/574/head
parent
9f42d78db7
commit
a50b0c9fb9
|
@ -23,3 +23,17 @@
|
|||
| :--------------------------------------------------------------------: | :--------------: | :------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.888 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210520_132236.log.json) |
|
||||
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.870 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_20210517-a44850da.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210517_205829.log.json) |
|
||||
|
||||
### WildReceiptOpenset
|
||||
|
||||
| Method | Modality | Edge F1-Score | Node Macro F1-Score | Node Micro F1-Score | Download |
|
||||
| :-------: | :----------: | :--------: | :--------: | :--------: | :--------: |
|
||||
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_openset.py) | Textual | 0.786 | 0.926 | 0.935 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_openset_20210917-d236b3ea.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210917_050824.log.json) |
|
||||
|
||||
|
||||
:::{note}
|
||||
1. In the openset case, the number of node categories is unknown or not fixed, and more node categories can be added.
|
||||
2. To show that our method can handle the openset problem, we modify the ground truth of `WildReceipt` to obtain `WildReceiptOpenset`. The `nodes` are classified into just 4 classes: `background, key, value, others`, and an `edge` label is added for each box.
|
||||
3. The model is used to predict whether two nodes form a pair connected by a valid edge.
|
||||
|
||||
:::
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# SDMGR config (textual modality only, 60 epochs) for the WildReceiptOpenset
# annotations. Runtime defaults are inherited from the base config below.
_base_ = ['../../_base_/default_runtime.py']

# Model: the visual branch is switched off and openset mode is enabled.
model = {
    'type': 'SDMGR',
    'backbone': {
        'type': 'UNet',
        'base_channels': 16
    },
    'bbox_head': {
        'type': 'SDMGRHead',
        'visual_dim': 16,
        'num_chars': 92,
        'num_classes': 4
    },
    'visual_modality': False,
    'train_cfg': None,
    'test_cfg': None,
    'class_list': None,
    'openset': True
}

# Optimisation schedule: Adam with step decay at epochs 40 and 50.
optimizer = {'type': 'Adam', 'weight_decay': 0.0001}
optimizer_config = {'grad_clip': None}
lr_config = {
    'policy': 'step',
    'warmup': 'linear',
    'warmup_iters': 1,
    'warmup_ratio': 1,
    'step': [40, 50]
}
total_epochs = 60

# Training pipeline: ground-truth labels are collected alongside the inputs.
train_pipeline = [
    {
        'type': 'LoadAnnotations'
    },
    {
        'type': 'ResizeNoImg',
        'img_scale': (1024, 512),
        'keep_ratio': True
    },
    {
        'type': 'KIEFormatBundle'
    },
    {
        'type': 'Collect',
        'keys': ['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'],
        'meta_keys': ('filename', 'ori_filename', 'ori_texts')
    },
]

# Test pipeline: no labels are collected, but the original boxes are kept in
# the meta information.
test_pipeline = [
    {
        'type': 'LoadAnnotations'
    },
    {
        'type': 'ResizeNoImg',
        'img_scale': (1024, 512),
        'keep_ratio': True
    },
    {
        'type': 'KIEFormatBundle'
    },
    {
        'type': 'Collect',
        'keys': ['img', 'relations', 'texts', 'gt_bboxes'],
        'meta_keys': ('filename', 'ori_filename', 'ori_texts', 'ori_boxes')
    },
]

dataset_type = 'OpensetKIEDataset'
data_root = 'data/wildreceipt'

# Annotation loader shared by the train and test splits (one JSON per line).
loader = {
    'type': 'HardDiskLoader',
    'repeat': 1,
    'parser': {
        'type': 'LineJsonParser',
        'keys': ['file_name', 'height', 'width', 'annotations']
    }
}

train = {
    'type': dataset_type,
    'ann_file': data_root + '/openset_train.txt',
    'pipeline': train_pipeline,
    'img_prefix': data_root,
    'link_type': 'one-to-many',
    'loader': loader,
    'dict_file': data_root + '/dict.txt',
    'test_mode': False
}
test = {
    'type': dataset_type,
    'ann_file': data_root + '/openset_test.txt',
    'pipeline': test_pipeline,
    'img_prefix': data_root,
    'link_type': 'one-to-many',
    'loader': loader,
    'dict_file': data_root + '/dict.txt',
    'test_mode': True
}

# The test split doubles as the validation split.
data = {
    'samples_per_gpu': 4,
    'workers_per_gpu': 1,
    'val_dataloader': {
        'samples_per_gpu': 1
    },
    'test_dataloader': {
        'samples_per_gpu': 1
    },
    'train': train,
    'val': test,
    'test': test
}

evaluation = {'interval': 1, 'metric': 'openset_f1', 'metric_options': None}

find_unused_parameters = True
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import evaluation
|
||||
from .mask import extract_boundary, points2boundary, seg2boundary
|
||||
from .visualize import (det_recog_show_result, imshow_edge_node,
|
||||
from .visualize import (det_recog_show_result, imshow_edge, imshow_edge_node,
|
||||
imshow_pred_boundary, imshow_text_char_boundary,
|
||||
imshow_text_label, overlay_mask_img, show_feature,
|
||||
show_img_boundary, show_pred_gt)
|
||||
|
@ -12,6 +12,6 @@ __all__ = [
|
|||
'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
|
||||
'show_feature', 'show_img_boundary', 'show_pred_gt',
|
||||
'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
|
||||
'imshow_edge_node', 'det_recog_show_result'
|
||||
'imshow_edge_node', 'det_recog_show_result', 'imshow_edge'
|
||||
]
|
||||
__all__ += evaluation.__all__
|
||||
|
|
|
@ -543,7 +543,15 @@ def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False):
|
|||
return out_img
|
||||
|
||||
|
||||
def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
|
||||
def draw_texts_by_pil(img,
|
||||
texts,
|
||||
boxes=None,
|
||||
draw_box=True,
|
||||
on_ori_img=False,
|
||||
font_size=None,
|
||||
fill_color=None,
|
||||
draw_pos=None,
|
||||
return_text_size=False):
|
||||
"""Draw boxes and texts on empty image, especially for Chinese.
|
||||
|
||||
Args:
|
||||
|
@ -552,9 +560,18 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
|
|||
boxes (list[list[float]]): Detected bounding boxes.
|
||||
draw_box (bool): Whether draw box or not. If False, draw text only.
|
||||
on_ori_img (bool): If True, draw box and text on input image,
|
||||
else, on a new empty image.
|
||||
Return:
|
||||
out_img (np.ndarray): Visualized text image.
|
||||
else on a new empty image.
|
||||
font_size (int, optional): Size to create a font object for a font.
|
||||
fill_color (tuple(int), optional): Fill color for text.
|
||||
draw_pos (tuple(int), optional): Start point to draw text.
|
||||
return_text_size (bool): If True, return the list of text size.
|
||||
|
||||
Returns:
|
||||
(np.ndarray, list[tuple]) or np.ndarray: Return a tuple
|
||||
``(out_img, text_sizes)``, where ``out_img`` is the output image
|
||||
with texts drawn on it and ``text_sizes`` are the size of drawing
|
||||
texts. If ``return_text_size`` is False, only the output image will be
|
||||
returned.
|
||||
"""
|
||||
|
||||
color_list = gen_color()
|
||||
|
@ -568,6 +585,8 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
|
|||
else:
|
||||
out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
|
||||
out_draw = ImageDraw.Draw(out_img)
|
||||
|
||||
text_sizes = []
|
||||
for idx, (box, text) in enumerate(zip(boxes, texts)):
|
||||
if len(text) == 0:
|
||||
continue
|
||||
|
@ -576,8 +595,6 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
|
|||
color = tuple(list(color_list[idx % len(color_list)])[::-1])
|
||||
if draw_box:
|
||||
out_draw.line(box, fill=color, width=1)
|
||||
box_width = max(max_x - min_x, max_y - min_y)
|
||||
font_size = int(0.9 * box_width / len(text))
|
||||
dirname, _ = os.path.split(os.path.abspath(__file__))
|
||||
font_path = os.path.join(dirname, 'font.TTF')
|
||||
if not os.path.exists(font_path):
|
||||
|
@ -585,13 +602,24 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
|
|||
print(f'Downloading {url} ...')
|
||||
local_filename, _ = urllib.request.urlretrieve(url)
|
||||
shutil.move(local_filename, font_path)
|
||||
if font_size is None:
|
||||
box_width = max(max_x - min_x, max_y - min_y)
|
||||
font_size = int(0.9 * box_width / len(text))
|
||||
fnt = ImageFont.truetype(font_path, font_size)
|
||||
out_draw.text((min_x + 1, min_y + 1), text, font=fnt, fill=(0, 0, 0))
|
||||
if draw_pos is None:
|
||||
draw_pos = (min_x + 1, min_y + 1)
|
||||
if fill_color is None:
|
||||
fill_color = (0, 0, 0)
|
||||
out_draw.text(draw_pos, text, font=fnt, fill=fill_color)
|
||||
text_sizes.append(fnt.getsize(text))
|
||||
|
||||
del out_draw
|
||||
|
||||
out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
if return_text_size:
|
||||
return out_img, text_sizes
|
||||
|
||||
return out_img
|
||||
|
||||
|
||||
|
@ -640,3 +668,202 @@ def det_recog_show_result(img, end2end_res, out_file=None):
|
|||
mmcv.imwrite(out_img, out_file)
|
||||
|
||||
return out_img
|
||||
|
||||
|
||||
def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
    """Draw text and their relationship on empty images.

    Each accepted key is written on the left of a white canvas, and every
    value linked to it is written on the right with an arrow from the key.
    The canvas is widened (a new column) or heightened when the current
    column runs out of vertical space.

    Args:
        img (np.ndarray): The original image. Only its height and width are
            used, to size the canvas.
        result (dict): The result of model forward_test, including:
            - img_metas (list[dict]): List of meta information dictionary.
            - nodes (Tensor): Node prediction with size:
                number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        edge_thresh (float): Score threshold for edge classification.
        keynode_thresh (float): Score threshold for node
            (``key``) classification.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """

    h, w = img.shape[:2]

    vis_area_width = w // 3 * 2
    vis_area_height = h
    dist_key_to_value = vis_area_width // 2
    dist_pair_to_pair = 30

    bbox_x1 = dist_pair_to_pair
    bbox_y1 = 0

    new_w = vis_area_width
    new_h = vis_area_height
    pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255

    nodes = result['nodes'].detach().cpu()
    texts = result['img_metas'][0]['ori_texts']
    num_nodes = result['nodes'].size(0)
    # Keep only the last score column and view it as a node-by-node matrix.
    edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes)

    # (i, j) will be a valid pair
    # either edge_score(node_i->node_j) > edge_thresh
    # or edge_score(node_j->node_i) > edge_thresh
    pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True)

    # 1. "for n1, n2 in zip(*pairs) if n1 < n2":
    #    Only (n1, n2) will be included if n1 < n2 but not (n2, n1), to
    #    avoid duplication.
    # 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]":
    #    nodes[n1, 1] is the score that this node is predicted as key,
    #    nodes[n1, 2] is the score that this node is predicted as value.
    #    If nodes[n1, 1] > nodes[n1, 2], n1 will be the index of key,
    #    so that n2 will be the index of value.
    result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1)
                    for n1, n2 in zip(*pairs) if n1 < n2]

    # Sorting groups all pairs that share a key next to each other.
    result_pairs.sort()
    key_current_idx = -1
    pos_current = (-1, -1)
    newline_flag = False

    key_font_size = 15
    value_font_size = 15
    key_font_color = (0, 0, 0)
    value_font_color = (0, 0, 255)
    arrow_color = (0, 0, 255)
    for pair in result_pairs:
        key_idx = int(pair[0].item())
        if nodes[key_idx, 1] < keynode_thresh:
            continue
        if key_idx != key_current_idx:
            # move y-coords down for a new key
            bbox_y1 += 10
            # enlarge blank area to show key-value info
            if newline_flag:
                bbox_x1 += vis_area_width
                tmp_img = np.ones(
                    (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_w += vis_area_width
                newline_flag = False
                bbox_y1 = 10
        key_text = texts[key_idx]
        key_pos = (bbox_x1, bbox_y1)
        value_idx = pair[1].item()
        value_text = texts[value_idx]
        value_pos = (bbox_x1 + dist_key_to_value, bbox_y1)
        if key_idx != key_current_idx:
            # draw text for a new key
            key_current_idx = key_idx
            # NOTE: the text colour keyword of ``draw_texts_by_pil`` is
            # ``fill_color``; passing ``font_color`` raises a TypeError.
            pred_edge_img, text_sizes = draw_texts_by_pil(
                pred_edge_img, [key_text],
                draw_box=False,
                on_ori_img=True,
                font_size=key_font_size,
                fill_color=key_font_color,
                draw_pos=key_pos,
                return_text_size=True)
            pos_right_bottom = (key_pos[0] + text_sizes[0][0],
                                key_pos[1] + text_sizes[0][1])
            # Remember where the arrows of this key start, so that further
            # values of the same key fan out from the same point.
            pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10)
            pred_edge_img = cv2.arrowedLine(
                pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10),
                (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color,
                1)
        else:
            # draw arrow from key to value
            if newline_flag:
                # Same key but no room left: extend the canvas downwards
                # instead of starting a new column.
                tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3),
                                  dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_h += dist_pair_to_pair
            pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current,
                                            (bbox_x1 + dist_key_to_value - 5,
                                             bbox_y1 + 10), arrow_color, 1)
        # draw text for value
        pred_edge_img = draw_texts_by_pil(
            pred_edge_img, [value_text],
            draw_box=False,
            on_ori_img=True,
            font_size=value_font_size,
            fill_color=value_font_color,
            draw_pos=value_pos,
            return_text_size=False)
        bbox_y1 += dist_pair_to_pair
        if bbox_y1 + dist_pair_to_pair >= new_h:
            newline_flag = True

    return pred_edge_img
|
||||
|
||||
|
||||
def imshow_edge(img,
                result,
                boxes,
                show=False,
                win_name='',
                wait_time=-1,
                out_file=None):
    """Visualize the node and edge predictions of the KIE model.

    The text boxes are outlined on the input image, and the key-value panel
    produced by ``draw_edge_result`` is pasted to its right on a white
    background.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:
            - img_metas (list[dict]): List of meta information dictionary.
            - nodes (Tensor): Node prediction with size:
                number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        boxes (list): The text boxes corresponding to the nodes.
        show (bool): Whether to show the image. Default: False.
        win_name (str): The window name. Default: ''
        wait_time (float): Value of waitKey param. Default: -1.
        out_file (str or None): The filename to write the image. When given,
            the raw predictions are also dumped to ``{out_file}_res.pkl``.
            Default: None.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """
    img = mmcv.imread(img)
    img_h, img_w = img.shape[:2]
    palette = gen_color()
    num_colors = len(palette)

    # Outline every text box, cycling through the colour palette.
    for box_idx, box in enumerate(boxes):
        left, top, right, bottom = box[0], box[1], box[2], box[3]
        corners = [[[left, top], [right, top], [right, bottom],
                    [left, bottom]]]
        polygon = np.array(corners, np.int32).reshape((-1, 1, 2))
        cv2.polylines(
            img, [polygon],
            True,
            color=palette[box_idx % num_colors],
            thickness=1)

    edge_panel = draw_edge_result(img, result)
    panel_h, panel_w = edge_panel.shape[:2]

    # Canvas: original image on the left, white area on the right that
    # receives the key-value panel.
    canvas_h = max(img_h, panel_h)
    canvas_w = img_w + panel_w
    vis_img = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    vis_img[:img_h, :img_w] = img
    vis_img[:, img_w:] = 255
    vis_img[:panel_h, img_w:(img_w + panel_w)] = edge_panel

    if show:
        mmcv.imshow(vis_img, win_name, wait_time)

    if out_file is not None:
        mmcv.imwrite(vis_img, out_file)
        # Keep the raw predictions next to the rendered image.
        raw_result = {
            'boxes': boxes,
            'nodes': result['nodes'].detach().cpu(),
            'edges': result['edges'].detach().cpu(),
            'metas': result['img_metas'][0]
        }
        mmcv.dump(raw_result, f'{out_file}_res.pkl')

    return vis_img
|
||||
|
|
|
@ -8,6 +8,7 @@ from .kie_dataset import KIEDataset
|
|||
from .ner_dataset import NerDataset
|
||||
from .ocr_dataset import OCRDataset
|
||||
from .ocr_seg_dataset import OCRSegDataset
|
||||
from .openset_kie_dataset import OpensetKIEDataset
|
||||
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
|
||||
from .text_det_dataset import TextDetDataset
|
||||
from .uniform_concat_dataset import UniformConcatDataset
|
||||
|
@ -18,7 +19,7 @@ __all__ = [
|
|||
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
|
||||
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
|
||||
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
|
||||
'NerDataset', 'UniformConcatDataset'
|
||||
'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset'
|
||||
]
|
||||
|
||||
__all__ += utils.__all__
|
||||
|
|
|
@ -0,0 +1,308 @@
|
|||
import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmdet.datasets.builder import DATASETS
|
||||
|
||||
from mmocr.datasets import KIEDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class OpensetKIEDataset(KIEDataset):
    """Openset KIE classifies the nodes (i.e. text boxes) into bg/key/value
    categories, and additionally learns key-value relationship among nodes.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        dict_file (str): Character dict file path.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        pipeline (list[dict]): Processing pipeline.
        norm (float): Norm to map value from one range to another.
        link_type (str): ``one-to-one`` | ``one-to-many`` |
            ``many-to-one`` | ``many-to-many`` | ``none``. For
            ``many-to-many``, one key box can have many values and vice
            versa.
        edge_thr (float): Score threshold for a valid edge.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
        key_node_idx (int): Index of key in node classes.
        value_node_idx (int): Index of value in node classes.
        node_classes (int): Number of node classes.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 dict_file,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 link_type='one-to-one',
                 edge_thr=0.5,
                 test_mode=True,
                 key_node_idx=1,
                 value_node_idx=2,
                 node_classes=4):
        # NOTE(review): the positional ``False`` is presumably the
        # ``directed`` flag of ``KIEDataset`` -- confirm against the parent.
        super().__init__(ann_file, loader, dict_file, img_prefix, pipeline,
                         norm, False, test_mode)
        assert link_type in [
            'one-to-one', 'one-to-many', 'many-to-one', 'many-to-many', 'none'
        ]
        self.link_type = link_type
        # Index the annotations by file name, for lookup during evaluation.
        self.data_dict = {x['file_name']: x for x in self.data_infos}
        self.edge_thr = edge_thr
        self.key_node_idx = key_node_idx
        self.value_node_idx = value_node_idx
        self.node_classes = node_classes

    def pre_pipeline(self, results):
        """Run the parent hook, then expose the original texts and boxes at
        the top level of ``results``."""
        super().pre_pipeline(results)
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['ori_boxes'] = results['ann_info']['ori_boxes']

    def list_to_numpy(self, ann_infos):
        """Extend the parent conversion with the untouched ``texts`` and
        ``boxes`` of ``ann_infos`` (as ``ori_texts`` / ``ori_boxes``)."""
        results = super().list_to_numpy(ann_infos)
        results.update(dict(ori_texts=ann_infos['texts']))
        results.update(dict(ori_boxes=ann_infos['boxes']))

        return results

    def evaluate(self,
                 results,
                 metric='openset_f1',
                 metric_options=None,
                 **kwargs):
        """Evaluate predictions with the openset F1 metrics.

        Args:
            results (list[dict]): Model outputs, one per image; each is
                decoded by :meth:`decode_pred`.
            metric (str | list[str]): Only ``openset_f1`` is supported.
            metric_options (dict, optional): Currently not used by the
                metric computation.

        Returns:
            dict: See :meth:`compute_openset_f1`.

        Raises:
            KeyError: If an unsupported metric is requested.
        """
        # Protect ``metric_options`` since it uses mutable value as default
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['openset_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        preds, gts = [], []
        for result in results:
            # data for preds
            pred = self.decode_pred(result)
            preds.append(pred)
            # data for gts
            gt = self.decode_gt(pred['filename'])
            gts.append(gt)

        return self.compute_openset_f1(preds, gts)

    @staticmethod
    def _f1_score(precision, recall):
        """Harmonic mean of precision and recall; 0 when both are 0."""
        denominator = precision + recall
        if denominator <= 0:
            return 0.0
        return 2 * precision * recall / denominator

    def _decode_pairs_gt(self, labels, edge_ids):
        """Find all pairs in gt.

        A key node and a value node form a pair when they carry the same
        edge id. The first index in the pair (n1, n2) is key.
        """
        labels = [int(label) for label in labels]
        gt_pairs = []
        for i, label in enumerate(labels):
            if label != self.key_node_idx:
                continue
            for j, edge_id in enumerate(edge_ids):
                if edge_id == edge_ids[i] and labels[j] == self.value_node_idx:
                    gt_pairs.append((i, j))

        return gt_pairs

    @staticmethod
    def _decode_pairs_pred(nodes,
                           labels,
                           edges,
                           edge_thr=0.5,
                           link_type='one-to-one',
                           key_node_idx=1,
                           value_node_idx=2):
        """Find all pairs in prediction.

        The first index in the pair (n1, n2) is more likely to be a key
        according to prediction in nodes.

        Args:
            nodes (Tensor): Node scores, number_node * node_classes.
            labels (Tensor): Predicted class index of each node.
            edges (Tensor): Edge scores, number_node * number_node.
            edge_thr (float): Score threshold for a valid edge.
            link_type (str): Linking constraint, see the class docstring.
            key_node_idx (int): Index of key in node classes.
            value_node_idx (int): Index of value in node classes.

        Returns:
            tuple(list[tuple[int, int]], list[float]): The ``(key, value)``
            index pairs and the (symmetrised) edge score of each pair.
        """
        # An edge counts in either direction: keep the larger of the two.
        edges = torch.max(edges, edges.T)
        if link_type in ['none', 'many-to-many']:
            pair_inds = (edges > edge_thr).nonzero(as_tuple=True)
            # Visit each unordered pair once (n1 < n2) and put the node that
            # looks more like a key first.
            pred_pairs = [
                (n1.item(), n2.item())
                if nodes[n1, key_node_idx] > nodes[n1, value_node_idx] else
                (n2.item(), n1.item()) for n1, n2 in zip(*pair_inds)
                if n1 < n2
            ]
            pred_pairs = [(i, j) for i, j in pred_pairs
                          if labels[i] == key_node_idx
                          and labels[j] == value_node_idx]
        else:
            # Rows must be keys and columns must be values; everything else
            # is masked out with -1.
            links = edges.clone()
            links[links <= edge_thr] = -1
            links[labels != key_node_idx, :] = -1
            links[:, labels != value_node_idx] = -1

            num_cols = links.size(1)
            pred_pairs = []
            # Greedily accept the strongest remaining link, then mask the
            # row and/or column that the link type forbids reusing.
            while (links > -1).any():
                i, j = divmod(int(torch.argmax(links)), num_cols)
                pred_pairs.append((i, j))
                if link_type == 'one-to-one':
                    links[i, :] = -1
                    links[:, j] = -1
                elif link_type == 'one-to-many':
                    links[:, j] = -1
                elif link_type == 'many-to-one':
                    links[i, :] = -1
                else:
                    raise ValueError(f'not supported link type {link_type}')

        pairs_conf = [edges[i, j].item() for i, j in pred_pairs]
        return pred_pairs, pairs_conf

    def decode_pred(self, result):
        """Decode prediction.

        Assemble boxes and predicted labels into bboxes, and convert edges into
        matrix.
        """
        filename = result['img_metas'][0]['ori_filename']
        nodes = result['nodes'].cpu()
        labels_conf, labels = torch.max(nodes, dim=-1)
        num_nodes = nodes.size(0)
        edges = result['edges'][:, -1].view(num_nodes, num_nodes).cpu()
        annos = self.data_dict[filename]['annotations']
        boxes = [x['box'] for x in annos]
        texts = [x['text'] for x in annos]
        # Keep entries 0, 1, 4, 5 of each box (presumably two opposite
        # corners of an 8-value quadrilateral) and append the label.
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs, pairs_conf = self._decode_pairs_pred(
            nodes, labels, edges, self.edge_thr, self.link_type,
            self.key_node_idx, self.value_node_idx)
        pred = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': labels_conf.tolist(),
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': pairs_conf
        }
        return pred

    def decode_gt(self, filename):
        """Decode ground truth.

        Assemble boxes and labels into bboxes.
        """
        annos = self.data_dict[filename]['annotations']
        labels = torch.Tensor([x['label'] for x in annos])
        texts = [x['text'] for x in annos]
        edge_ids = [x['edge'] for x in annos]
        boxes = [x['box'] for x in annos]
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs = self._decode_pairs_gt(labels, edge_ids)
        gt = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': [1. for _ in labels],
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': [1. for _ in pairs]
        }
        return gt

    def compute_openset_f1(self, preds, gts):
        """Compute openset macro-f1 and micro-f1 score.

        Edge F1 is computed over all ``(key, value)`` pairs of the dataset.
        Node F1 only considers the key and value classes.

        Args:
            preds: (list[dict]): List of prediction results, including
                keys: ``filename``, ``pairs``, etc.
            gts: (list[dict]): List of ground-truth infos, including
                keys: ``filename``, ``pairs``, etc.

        Returns:
            dict: Evaluation result with keys: ``node_openset_micro_f1``, \
                ``node_openset_macro_f1``, ``edge_openset_f1``.
        """

        total_edge_hit_num, total_edge_gt_num, total_edge_pred_num = 0, 0, 0
        node_inds = list(range(self.node_classes))
        total_node_hit_num = {node_idx: 0 for node_idx in node_inds}
        total_node_gt_num = {node_idx: 0 for node_idx in node_inds}
        total_node_pred_num = {node_idx: 0 for node_idx in node_inds}

        for pred, gt in zip(preds, gts):
            # edge metric related
            pairs_pred = pred['pairs']
            pairs_gt = gt['pairs']
            pred_pair_set = {tuple(pair) for pair in pairs_pred}
            total_edge_hit_num += sum(
                1 for pair in pairs_gt if tuple(pair) in pred_pair_set)
            total_edge_gt_num += len(pairs_gt)
            total_edge_pred_num += len(pairs_pred)

            # node metric related
            nodes_pred = pred['labels']
            nodes_gt = gt['labels']
            for node_pred, node_gt in zip(nodes_pred, nodes_gt):
                node_gt = int(node_gt)
                total_node_gt_num[node_gt] += 1
                if node_pred == node_gt:
                    total_node_hit_num[node_gt] += 1
            for node_pred in nodes_pred:
                total_node_pred_num[node_pred] += 1

        # edge f1
        # ``max(1, count)`` only guards the divisions against empty counts.
        total_edge_recall = 1.0 * total_edge_hit_num / max(
            1, total_edge_gt_num)
        total_edge_precision = 1.0 * total_edge_hit_num / max(
            1, total_edge_pred_num)
        # F1 must be normalised by ``precision + recall`` itself: clamping
        # that sum to at least 1 would under-report every F1 below 0.5.
        stats = {
            'edge_openset_f1':
            self._f1_score(total_edge_precision, total_edge_recall)
        }

        # node f1
        cared_node_hit_num, cared_node_gt_num, cared_node_pred_num = 0, 0, 0
        node_macro_f1 = []
        cared_node_inds = (self.key_node_idx, self.value_node_idx)
        for node_idx in node_inds:
            # Only the key and value classes take part in the node metrics.
            if node_idx not in cared_node_inds:
                continue
            cared_node_hit_num += total_node_hit_num[node_idx]
            cared_node_gt_num += total_node_gt_num[node_idx]
            cared_node_pred_num += total_node_pred_num[node_idx]
            recall = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_gt_num[node_idx])
            precision = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_pred_num[node_idx])
            node_macro_f1.append(self._f1_score(precision, recall))

        node_micro_recall = 1.0 * cared_node_hit_num / max(
            1, cared_node_gt_num)
        node_micro_precision = 1.0 * cared_node_hit_num / max(
            1, cared_node_pred_num)

        stats['node_openset_micro_f1'] = self._f1_score(
            node_micro_precision, node_micro_recall)
        stats['node_openset_macro_f1'] = np.mean(node_macro_f1)

        return stats
|
|
@ -6,7 +6,7 @@ from mmdet.core import bbox2roi
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmocr.core import imshow_edge_node
|
||||
from mmocr.core import imshow_edge, imshow_edge_node
|
||||
from mmocr.models.builder import DETECTORS, build_roi_extractor
|
||||
from mmocr.models.common.detectors import SingleStageDetector
|
||||
from mmocr.utils import list_from_file
|
||||
|
@ -36,7 +36,8 @@ class SDMGR(SingleStageDetector):
|
|||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
class_list=None,
|
||||
init_cfg=None):
|
||||
init_cfg=None,
|
||||
openset=False):
|
||||
super().__init__(
|
||||
backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
|
||||
self.visual_modality = visual_modality
|
||||
|
@ -49,6 +50,7 @@ class SDMGR(SingleStageDetector):
|
|||
else:
|
||||
self.extractor = None
|
||||
self.class_list = class_list
|
||||
self.openset = openset
|
||||
|
||||
def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
|
||||
gt_labels):
|
||||
|
@ -136,15 +138,25 @@ class SDMGR(SingleStageDetector):
|
|||
if out_file is not None:
|
||||
show = False
|
||||
|
||||
img = imshow_edge_node(
|
||||
img,
|
||||
result,
|
||||
boxes,
|
||||
idx_to_cls=idx_to_cls,
|
||||
show=show,
|
||||
win_name=win_name,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file)
|
||||
if self.openset:
|
||||
img = imshow_edge(
|
||||
img,
|
||||
result,
|
||||
boxes,
|
||||
show=show,
|
||||
win_name=win_name,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file)
|
||||
else:
|
||||
img = imshow_edge_node(
|
||||
img,
|
||||
result,
|
||||
boxes,
|
||||
idx_to_cls=idx_to_cls,
|
||||
show=show,
|
||||
win_name=win_name,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file)
|
||||
|
||||
if not (show or out_file):
|
||||
warnings.warn('show==False and out_file is not specified, only '
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import math
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.datasets.openset_kie_dataset import OpensetKIEDataset
|
||||
from mmocr.utils import list_to_file
|
||||
|
||||
|
||||
def _create_dummy_ann_file(ann_file):
    """Write a one-image annotation file and return the annotation dict.

    The image holds one key node ('store') and one value node ('MyFamily')
    that share the same edge id.
    """
    key_node = {
        'text': 'store',
        'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0],
        'label': 1,
        'edge': 1
    }
    value_node = {
        'text': 'MyFamily',
        'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0],
        'label': 2,
        'edge': 1
    }
    ann_info1 = {
        'file_name': '1.png',
        'height': 200,
        'width': 200,
        'annotations': [key_node, value_node]
    }
    # One JSON document per line.
    lines = [json.dumps(ann_info1)]
    list_to_file(ann_file, lines)

    return ann_info1
|
||||
|
||||
|
||||
def _create_dummy_dict_file(dict_file):
    """Write a tiny character dictionary, one character per entry."""
    characters = [char for char in '0123']
    list_to_file(dict_file, characters)
|
||||
|
||||
|
||||
def _create_dummy_loader():
    """Return the loader config used to read the dummy annotation file."""
    parser_cfg = {
        'type': 'LineJsonParser',
        'keys': ['file_name', 'height', 'width', 'annotations']
    }
    return {'type': 'HardDiskLoader', 'repeat': 1, 'parser': parser_cfg}
|
||||
|
||||
|
||||
def test_openset_kie_dataset():
    """Exercise OpensetKIEDataset: init, pre_pipeline and evaluation."""
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        # create dummy data
        ann_file = osp.join(tmp_dir_name, 'fake_data.txt')
        dict_file = osp.join(tmp_dir_name, 'fake_dict.txt')
        ann_info1 = _create_dummy_ann_file(ann_file)
        _create_dummy_dict_file(dict_file)

        # test initialization
        dataset = OpensetKIEDataset(
            ann_file, _create_dummy_loader(), dict_file, pipeline=[])

        dataset.prepare_train_img(0)

        # test pre_pipeline
        first_info = dataset.data_infos[0]
        results = {
            'img_info': {
                'filename': first_info['file_name'],
                'height': first_info['height'],
                'width': first_info['width']
            },
            'ann_info': dataset._parse_anno_info(first_info['annotations'])
        }
        dataset.pre_pipeline(results)
        assert results['img_prefix'] == dataset.img_prefix
        assert 'ori_texts' in results

        # test evaluation
        dummy_annos = ann_info1['annotations']
        img_meta = {
            'filename': ann_info1['file_name'],
            'ori_filename': ann_info1['file_name'],
            'ori_texts': [anno['text'] for anno in dummy_annos],
            'ori_boxes': [anno['box'] for anno in dummy_annos]
        }
        result = {
            'img_metas': [img_meta],
            # one score row per box
            'nodes': torch.tensor([[0.01, 0.8, 0.01, 0.18],
                                   [0.01, 0.01, 0.9, 0.08]]),
            # four identical edge score rows
            'edges': torch.Tensor([[0.01, 0.99]] * 4)
        }

        eval_res = dataset.evaluate([result])
        assert math.isclose(eval_res['edge_openset_f1'], 1.0, abs_tol=1e-4)
|
|
@ -0,0 +1,121 @@
|
|||
import argparse
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmocr.utils import list_from_file, list_to_file
|
||||
|
||||
|
||||
def convert(closeset_line, merge_bg_others=False, ignore_idx=0, others_idx=25):
    """Convert line-json str of closeset to line-json str of openset.

    Note that this function is designed for closeset-wildreceipt to
    openset-wildreceipt. It may not be suitable to your own dataset.

    Every annotation receives an openset node ``label`` (bg / key / value /
    others) and an ``edge`` id. All boxes of one closeset key class and of
    its paired value class share one ``edge`` id, while every ``ignore`` or
    ``others`` box gets an id of its own.

    Args:
        closeset_line (str): The string to be deserialized to
            the closeset dictionary object.
        merge_bg_others (bool): If True, give the same label to "background"
            class and "others" class.
        ignore_idx (int): Index for ``ignore`` class.
        others_idx (int): Index for ``others`` class.

    Returns:
        str: The json-serialized openset dictionary object.

    Raises:
        ValueError: If an annotation label is neither ``ignore_idx`` nor
            ``others_idx`` and does not belong to any key-value pair.
    """
    # Closeset labels below ``others_idx`` come in (key, value) pairs: the
    # i-th even label (2, 4, ...) is a key and the i-th odd label
    # (1, 3, ...) is its value. For example, in wildreceipt, label 2 is
    # "Store_name_key" and label 1 is its value counterpart (presumably
    # "Store_name_value"; the previous comment said "Store_addr_value" -
    # TODO confirm against the class list).
    # ``zip`` drops a trailing label that has no partner.
    key_to_value = dict(
        zip(range(2, others_idx, 2), range(1, others_idx, 2)))
    value_to_key = {value: key for key, value in key_to_value.items()}

    openset_node_label_mapping = {'bg': 0, 'key': 1, 'value': 2, 'others': 3}
    if merge_bg_others:
        openset_node_label_mapping['others'] = openset_node_label_mapping['bg']

    closeset_obj = json.loads(closeset_line)
    annotations = closeset_obj['annotations']

    edge_idx = 1  # next unused edge id
    label_to_edge = {}  # closeset label -> edge id already assigned to it
    for anno in annotations:
        label = anno['label']
        if label == ignore_idx:
            anno['label'] = openset_node_label_mapping['bg']
            anno['edge'] = edge_idx
            edge_idx += 1
        elif label == others_idx:
            anno['label'] = openset_node_label_mapping['others']
            anno['edge'] = edge_idx
            edge_idx += 1
        else:
            if label in key_to_value:
                role = 'key'
                counterpart = key_to_value[label]
            elif label in value_to_key:
                role = 'value'
                counterpart = value_to_key[label]
            else:
                # Previously this fell through to an unbound local variable
                # (or an out-of-range index) and crashed without context.
                raise ValueError(
                    f'Label {label!r} is neither ignore_idx ({ignore_idx}) '
                    f'nor others_idx ({others_idx}) and has no key-value '
                    'counterpart.')

            edge = label_to_edge.get(label)
            if edge is None:
                # First box of this class: reuse the edge of the paired
                # class if it was seen already, otherwise open a new edge.
                edge = label_to_edge.get(counterpart)
                if edge is None:
                    edge = edge_idx
                    edge_idx += 1
                label_to_edge[label] = edge

            anno['edge'] = edge
            anno['label'] = openset_node_label_mapping[role]

    openset_obj = {
        'file_name': closeset_obj['file_name'],
        'height': closeset_obj['height'],
        'width': closeset_obj['width'],
        'annotations': annotations
    }

    return json.dumps(openset_obj, ensure_ascii=False)
|
||||
|
||||
|
||||
def process(closeset_file, openset_file, merge_bg_others=False, n_proc=10):
    """Convert every line of a closeset annotation file and save the result.

    Args:
        closeset_file (str): Path of the closeset annotation file to read.
        openset_file (str): Path of the openset annotation file to write.
        merge_bg_others (bool): Forwarded to :func:`convert`.
        n_proc (int): Passed as ``nproc`` to
            ``mmcv.track_parallel_progress``.
    """
    # ``partial`` binds the merge option so each line needs one argument.
    line_converter = partial(convert, merge_bg_others=merge_bg_others)

    converted_lines = mmcv.track_parallel_progress(
        line_converter, list_from_file(closeset_file), nproc=n_proc)

    list_to_file(openset_file, converted_lines)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line arguments of the conversion script.

    Returns:
        argparse.Namespace: Holds ``in_file``, ``out_file``, ``merge`` and
            ``n_proc``.
    """
    merge_help = ('Merge two classes: "background" and "others" in closeset '
                  'to one class in openset.')

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('in_file', help='Annotation file for closeset.')
    arg_parser.add_argument('out_file', help='Annotation file for openset.')
    arg_parser.add_argument('--merge', action='store_true', help=merge_help)
    arg_parser.add_argument(
        '--n_proc', type=int, default=10, help='Number of process.')
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Script entry point: parse the arguments and run the conversion."""
    cli_args = parse_args()

    process(
        cli_args.in_file,
        cli_args.out_file,
        cli_args.merge,
        cli_args.n_proc,
    )

    print('finish')


if __name__ == '__main__':
    main()
|
Loading…
Reference in New Issue