[Feature] Support openset kie (#498)

* add openset kie dataset

* update readme

* add anno convert script

* update docstring

* update script

* add & update docstring

* fix typo

* update docstring format
Hongbin Sun 2021-11-11 14:47:38 +08:00 committed by GitHub
parent 9f42d78db7
commit a50b0c9fb9
9 changed files with 885 additions and 21 deletions

@@ -23,3 +23,17 @@
| :--------------------------------------------------------------------: | :--------------: | :------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.888 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210520_132236.log.json) |
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.870 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_20210517-a44850da.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210517_205829.log.json) |
### WildReceiptOpenset
| Method | Modality | Edge F1-Score | Node Macro F1-Score | Node Micro F1-Score | Download |
| :-------: | :----------: | :--------: | :--------: | :--------: | :--------: |
| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_openset.py) | Textual | 0.786 | 0.926 | 0.935 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_openset_20210917-d236b3ea.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210917_050824.log.json) |
:::{note}
1. In the openset case, the number of node categories is unknown or not fixed in advance, so more node categories can be added.
2. To show that our method can handle the openset problem, we modify the ground truth of `WildReceipt` to create `WildReceiptOpenset`. The `nodes` are classified into just 4 classes: `background`, `key`, `value` and `others`, while an `edge` label is added for each box (see the example annotation line below).
3. The model predicts whether two nodes form a pair connected by a valid edge.
:::
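
For reference, each line of the converted openset annotation file is a JSON object in which every box carries a `label` (0: background, 1: key, 2: value, 3: others) and an `edge` id, and boxes sharing an `edge` id form a key-value pair. A sketch of one such line, mirroring the dummy data used in the unit test of this PR:

```python
import json

# One line of the openset annotation file. Boxes sharing the same `edge` id
# form a key-value pair.
line = json.dumps({
    'file_name': '1.png',
    'height': 200,
    'width': 200,
    'annotations': [{
        'text': 'store',
        'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0],
        'label': 1,  # key
        'edge': 1
    }, {
        'text': 'MyFamily',
        'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0],
        'label': 2,  # value
        'edge': 1
    }]
})
```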

@@ -0,0 +1,83 @@
_base_ = ['../../_base_/default_runtime.py']

model = dict(
    type='SDMGR',
    backbone=dict(type='UNet', base_channels=16),
    bbox_head=dict(
        type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=4),
    visual_modality=False,
    train_cfg=None,
    test_cfg=None,
    class_list=None,
    openset=True)

optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1,
    step=[40, 50])
total_epochs = 60

train_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='ResizeNoImg', img_scale=(1024, 512), keep_ratio=True),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'],
        meta_keys=('filename', 'ori_filename', 'ori_texts'))
]
test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='ResizeNoImg', img_scale=(1024, 512), keep_ratio=True),
    dict(type='KIEFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'relations', 'texts', 'gt_bboxes'],
        meta_keys=('filename', 'ori_filename', 'ori_texts', 'ori_boxes'))
]

dataset_type = 'OpensetKIEDataset'
data_root = 'data/wildreceipt'

loader = dict(
    type='HardDiskLoader',
    repeat=1,
    parser=dict(
        type='LineJsonParser',
        keys=['file_name', 'height', 'width', 'annotations']))

train = dict(
    type=dataset_type,
    ann_file=f'{data_root}/openset_train.txt',
    pipeline=train_pipeline,
    img_prefix=data_root,
    link_type='one-to-many',
    loader=loader,
    dict_file=f'{data_root}/dict.txt',
    test_mode=False)
test = dict(
    type=dataset_type,
    ann_file=f'{data_root}/openset_test.txt',
    pipeline=test_pipeline,
    img_prefix=data_root,
    link_type='one-to-many',
    loader=loader,
    dict_file=f'{data_root}/dict.txt',
    test_mode=True)

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=1,
    val_dataloader=dict(samples_per_gpu=1),
    test_dataloader=dict(samples_per_gpu=1),
    train=train,
    val=test,
    test=test)

evaluation = dict(interval=1, metric='openset_f1', metric_options=None)

find_unused_parameters = True
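
To sanity-check the config, it can be loaded with `mmcv.Config` and the assembled fields inspected; a minimal sketch, assuming mmocr is installed and the file is saved at the path listed in the README table above:

```python
from mmcv import Config

cfg = Config.fromfile(
    'configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_openset.py')
print(cfg.model.type)       # -> 'SDMGR'
print(cfg.model.openset)    # -> True
print(cfg.data.train.type)  # -> 'OpensetKIEDataset'
```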

@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import evaluation
from .mask import extract_boundary, points2boundary, seg2boundary
-from .visualize import (det_recog_show_result, imshow_edge_node,
+from .visualize import (det_recog_show_result, imshow_edge, imshow_edge_node,
                         imshow_pred_boundary, imshow_text_char_boundary,
                         imshow_text_label, overlay_mask_img, show_feature,
                         show_img_boundary, show_pred_gt)
@@ -12,6 +12,6 @@ __all__ = [
    'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img',
    'show_feature', 'show_img_boundary', 'show_pred_gt',
    'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label',
-    'imshow_edge_node', 'det_recog_show_result'
+    'imshow_edge_node', 'det_recog_show_result', 'imshow_edge'
]
__all__ += evaluation.__all__

@@ -543,7 +543,15 @@ def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False):
    return out_img


-def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
+def draw_texts_by_pil(img,
+                      texts,
+                      boxes=None,
+                      draw_box=True,
+                      on_ori_img=False,
+                      font_size=None,
+                      fill_color=None,
+                      draw_pos=None,
+                      return_text_size=False):
    """Draw boxes and texts on empty image, especially for Chinese.

    Args:
@@ -552,9 +560,18 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
        boxes (list[list[float]]): Detected bounding boxes.
        draw_box (bool): Whether to draw the boxes. If False, draw text only.
        on_ori_img (bool): If True, draw the box and text on the input image,
-            else, on a new empty image.
-
-    Return:
-        out_img (np.ndarray): Visualized text image.
+            else on a new empty image.
+        font_size (int, optional): Font size used to create the font object.
+        fill_color (tuple(int), optional): Fill color for the text.
+        draw_pos (tuple(int), optional): Start point at which to draw the
+            text.
+        return_text_size (bool): If True, also return the list of text sizes.
+
+    Returns:
+        (np.ndarray, list[tuple]) or np.ndarray: A tuple
+        ``(out_img, text_sizes)``, where ``out_img`` is the output image
+        with the texts drawn on it and ``text_sizes`` are the sizes of the
+        drawn texts. If ``return_text_size`` is False, only the output
+        image is returned.
    """

    color_list = gen_color()
@@ -568,6 +585,8 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
    else:
        out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
    out_draw = ImageDraw.Draw(out_img)
+
+    text_sizes = []
    for idx, (box, text) in enumerate(zip(boxes, texts)):
        if len(text) == 0:
            continue
@@ -576,8 +595,6 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
        color = tuple(list(color_list[idx % len(color_list)])[::-1])
        if draw_box:
            out_draw.line(box, fill=color, width=1)
-        box_width = max(max_x - min_x, max_y - min_y)
-        font_size = int(0.9 * box_width / len(text))
        dirname, _ = os.path.split(os.path.abspath(__file__))
        font_path = os.path.join(dirname, 'font.TTF')
        if not os.path.exists(font_path):
@@ -585,13 +602,24 @@ def draw_texts_by_pil(img, texts, boxes=None, draw_box=True, on_ori_img=False):
            print(f'Downloading {url} ...')
            local_filename, _ = urllib.request.urlretrieve(url)
            shutil.move(local_filename, font_path)
+        if font_size is None:
+            box_width = max(max_x - min_x, max_y - min_y)
+            font_size = int(0.9 * box_width / len(text))
        fnt = ImageFont.truetype(font_path, font_size)
-        out_draw.text((min_x + 1, min_y + 1), text, font=fnt, fill=(0, 0, 0))
+        if draw_pos is None:
+            draw_pos = (min_x + 1, min_y + 1)
+        if fill_color is None:
+            fill_color = (0, 0, 0)
+        out_draw.text(draw_pos, text, font=fnt, fill=fill_color)
+        text_sizes.append(fnt.getsize(text))

    del out_draw
    out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)
+
+    if return_text_size:
+        return out_img, text_sizes
+
    return out_img
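
The extended signature stays backward compatible: when the new arguments are left as `None`/`False`, the old per-box behaviour applies. A small usage sketch with made-up values for the new keyword arguments:

```python
import numpy as np

# Draw one text at an explicit position with a fixed font size and colour,
# and get back the rendered text size (all values are made up).
canvas = np.full((64, 256, 3), 255, dtype=np.uint8)
out_img, text_sizes = draw_texts_by_pil(
    canvas, ['Hello'],
    boxes=[[2, 2, 120, 2, 120, 30, 2, 30]],
    draw_box=False,
    on_ori_img=True,
    font_size=15,
    fill_color=(0, 0, 255),
    draw_pos=(4, 4),
    return_text_size=True)
print(text_sizes)  # e.g. [(37, 15)]
```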
@@ -640,3 +668,202 @@ def det_recog_show_result(img, end2end_res, out_file=None):
        mmcv.imwrite(out_img, out_file)
    return out_img


def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
    """Draw texts and their relationships on an empty image.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:

            - img_metas (list[dict]): List of meta information dictionaries.
            - nodes (Tensor): Node prediction with size:
              number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        edge_thresh (float): Score threshold for edge classification.
        keynode_thresh (float): Score threshold for node
            (``key``) classification.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """
    h, w = img.shape[:2]

    vis_area_width = w // 3 * 2
    vis_area_height = h
    dist_key_to_value = vis_area_width // 2
    dist_pair_to_pair = 30

    bbox_x1 = dist_pair_to_pair
    bbox_y1 = 0

    new_w = vis_area_width
    new_h = vis_area_height
    pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255

    nodes = result['nodes'].detach().cpu()
    texts = result['img_metas'][0]['ori_texts']
    num_nodes = result['nodes'].size(0)
    edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes)

    # (i, j) will be a valid pair if
    # either edge_score(node_i->node_j) > edge_thresh
    # or edge_score(node_j->node_i) > edge_thresh
    pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True)

    # 1. "for n1, n2 in zip(*pairs) if n1 < n2":
    #    Only (n1, n2) will be included if n1 < n2 but not (n2, n1), to
    #    avoid duplication.
    # 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]":
    #    nodes[n1, 1] is the score that this node is predicted as key,
    #    nodes[n1, 2] is the score that this node is predicted as value.
    #    If nodes[n1, 1] > nodes[n1, 2], n1 will be the index of key,
    #    so that n2 will be the index of value.
    result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1)
                    for n1, n2 in zip(*pairs) if n1 < n2]
    result_pairs.sort()

    key_current_idx = -1
    pos_current = (-1, -1)
    newline_flag = False
    key_font_size = 15
    value_font_size = 15
    key_font_color = (0, 0, 0)
    value_font_color = (0, 0, 255)
    arrow_color = (0, 0, 255)
    for pair in result_pairs:
        key_idx = int(pair[0].item())
        if nodes[key_idx, 1] < keynode_thresh:
            continue
        if key_idx != key_current_idx:
            # move y-coords down for a new key
            bbox_y1 += 10
            # enlarge the blank area to show key-value info
            if newline_flag:
                bbox_x1 += vis_area_width
                tmp_img = np.ones(
                    (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_w += vis_area_width
                newline_flag = False
                bbox_y1 = 10
        key_text = texts[key_idx]
        key_pos = (bbox_x1, bbox_y1)
        value_idx = pair[1].item()
        value_text = texts[value_idx]
        value_pos = (bbox_x1 + dist_key_to_value, bbox_y1)
        if key_idx != key_current_idx:
            # draw text for a new key
            key_current_idx = key_idx
            # note: the keyword is fill_color (not font_color), matching the
            # extended signature of draw_texts_by_pil above
            pred_edge_img, text_sizes = draw_texts_by_pil(
                pred_edge_img, [key_text],
                draw_box=False,
                on_ori_img=True,
                font_size=key_font_size,
                fill_color=key_font_color,
                draw_pos=key_pos,
                return_text_size=True)
            pos_right_bottom = (key_pos[0] + text_sizes[0][0],
                                key_pos[1] + text_sizes[0][1])
            pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10)
            pred_edge_img = cv2.arrowedLine(
                pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10),
                (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color,
                1)
        else:
            # draw arrow from key to value
            if newline_flag:
                tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3),
                                  dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_h += dist_pair_to_pair
            pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current,
                                            (bbox_x1 + dist_key_to_value - 5,
                                             bbox_y1 + 10), arrow_color, 1)
        # draw text for value
        pred_edge_img = draw_texts_by_pil(
            pred_edge_img, [value_text],
            draw_box=False,
            on_ori_img=True,
            font_size=value_font_size,
            fill_color=value_font_color,
            draw_pos=value_pos,
            return_text_size=False)
        bbox_y1 += dist_pair_to_pair
        if bbox_y1 + dist_pair_to_pair >= new_h:
            newline_flag = True

    return pred_edge_img
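
The pair decoding above symmetrizes the directed edge scores before thresholding, so a pair survives if either direction clears `edge_thresh`. The trick in isolation, with made-up scores:

```python
import torch

# Made-up directed edge scores for 3 nodes.
edges = torch.tensor([[0.0, 0.9, 0.1],
                      [0.2, 0.0, 0.3],
                      [0.8, 0.1, 0.0]])
sym = torch.max(edges, edges.T)  # valid if either direction scores high
pairs = (sym > 0.5).nonzero(as_tuple=True)
print([(int(a), int(b)) for a, b in zip(*pairs) if a < b])
# -> [(0, 1), (0, 2)]
```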


def imshow_edge(img,
                result,
                boxes,
                show=False,
                win_name='',
                wait_time=-1,
                out_file=None):
    """Display the prediction results of the nodes and edges of the KIE model.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:

            - img_metas (list[dict]): List of meta information dictionaries.
            - nodes (Tensor): Node prediction with size: \
              number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        boxes (list): The text boxes corresponding to the nodes.
        show (bool): Whether to show the image. Default: False.
        win_name (str): The window name. Default: ''.
        wait_time (float): Value of waitKey param. Default: -1.
        out_file (str or None): The filename to write the image.
            Default: None.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """
    img = mmcv.imread(img)
    h, w = img.shape[:2]
    color_list = gen_color()

    for i, box in enumerate(boxes):
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [pts.reshape((-1, 1, 2))],
            True,
            color=color_list[i % len(color_list)],
            thickness=1)

    pred_img_h = h
    pred_img_w = w

    pred_edge_img = draw_edge_result(img, result)
    pred_img_h = max(pred_img_h, pred_edge_img.shape[0])
    pred_img_w += pred_edge_img.shape[1]

    vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8)
    vis_img[:h, :w] = img
    vis_img[:, w:] = 255

    height_t, width_t = pred_edge_img.shape[:2]
    vis_img[:height_t, w:(w + width_t)] = pred_edge_img

    if show:
        mmcv.imshow(vis_img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(vis_img, out_file)
        # only dump the raw results alongside an explicitly requested
        # output file (avoids writing a stray 'None_res.pkl')
        res_dic = {
            'boxes': boxes,
            'nodes': result['nodes'].detach().cpu(),
            'edges': result['edges'].detach().cpu(),
            'metas': result['img_metas'][0]
        }
        mmcv.dump(res_dic, f'{out_file}_res.pkl')

    return vis_img

@@ -8,6 +8,7 @@ from .kie_dataset import KIEDataset
from .ner_dataset import NerDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
+from .openset_kie_dataset import OpensetKIEDataset
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
from .text_det_dataset import TextDetDataset
from .uniform_concat_dataset import UniformConcatDataset
@@ -18,7 +19,7 @@ __all__ = [
    'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
    'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
    'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
-    'NerDataset', 'UniformConcatDataset'
+    'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset'
]
__all__ += utils.__all__

@@ -0,0 +1,308 @@
import copy

import numpy as np
import torch
from mmdet.datasets.builder import DATASETS

from mmocr.datasets import KIEDataset


@DATASETS.register_module()
class OpensetKIEDataset(KIEDataset):
    """Openset KIE classifies the nodes (i.e. text boxes) into bg/key/value
    categories, and additionally learns key-value relationships among nodes.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        dict_file (str): Character dict file path.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        pipeline (list[dict]): Processing pipeline.
        norm (float): Norm to map value from one range to another.
        link_type (str): ``one-to-one`` | ``one-to-many`` |
            ``many-to-one`` | ``many-to-many``. For ``many-to-many``,
            one key box can have many values and vice versa.
        edge_thr (float): Score threshold for a valid edge.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
        key_node_idx (int): Index of key in node classes.
        value_node_idx (int): Index of value in node classes.
        node_classes (int): Number of node classes.
    """
    def __init__(self,
                 ann_file,
                 loader,
                 dict_file,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 link_type='one-to-one',
                 edge_thr=0.5,
                 test_mode=True,
                 key_node_idx=1,
                 value_node_idx=2,
                 node_classes=4):
        super().__init__(ann_file, loader, dict_file, img_prefix, pipeline,
                         norm, False, test_mode)
        assert link_type in [
            'one-to-one', 'one-to-many', 'many-to-one', 'many-to-many', 'none'
        ]
        self.link_type = link_type
        self.data_dict = {x['file_name']: x for x in self.data_infos}
        self.edge_thr = edge_thr
        self.key_node_idx = key_node_idx
        self.value_node_idx = value_node_idx
        self.node_classes = node_classes

    def pre_pipeline(self, results):
        super().pre_pipeline(results)
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['ori_boxes'] = results['ann_info']['ori_boxes']

    def list_to_numpy(self, ann_infos):
        results = super().list_to_numpy(ann_infos)
        results.update(dict(ori_texts=ann_infos['texts']))
        results.update(dict(ori_boxes=ann_infos['boxes']))
        return results
    def evaluate(self,
                 results,
                 metric='openset_f1',
                 metric_options=None,
                 **kwargs):
        # Protect ``metric_options`` since it uses mutable value as default
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['openset_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        preds, gts = [], []
        for result in results:
            # data for preds
            pred = self.decode_pred(result)
            preds.append(pred)
            # data for gts
            gt = self.decode_gt(pred['filename'])
            gts.append(gt)

        return self.compute_openset_f1(preds, gts)

    def _decode_pairs_gt(self, labels, edge_ids):
        """Find all pairs in gt.

        The first index in the pair (n1, n2) is key.
        """
        gt_pairs = []
        for i, label in enumerate(labels):
            if label == self.key_node_idx:
                for j, edge_id in enumerate(edge_ids):
                    if (edge_id == edge_ids[i]
                            and labels[j] == self.value_node_idx):
                        gt_pairs.append((i, j))
        return gt_pairs
    @staticmethod
    def _decode_pairs_pred(nodes,
                           labels,
                           edges,
                           edge_thr=0.5,
                           link_type='one-to-one'):
        """Find all pairs in prediction.

        The first index in the pair (n1, n2) is more likely to be a key
        according to the prediction in nodes.
        """
        edges = torch.max(edges, edges.T)
        if link_type in ['none', 'many-to-many']:
            pair_inds = (edges > edge_thr).nonzero(as_tuple=True)
            pred_pairs = [(n1.item(), n2.item())
                          if nodes[n1, 1] > nodes[n1, 2] else
                          (n2.item(), n1.item())
                          for n1, n2 in zip(*pair_inds) if n1 < n2]
            pred_pairs = [(i, j) for i, j in pred_pairs
                          if labels[i] == 1 and labels[j] == 2]
        else:
            links = edges.clone()
            links[links <= edge_thr] = -1
            links[labels != 1, :] = -1
            links[:, labels != 2] = -1

            pred_pairs = []
            while (links > -1).any():
                i, j = np.unravel_index(torch.argmax(links), links.shape)
                pred_pairs.append((i, j))
                if link_type == 'one-to-one':
                    links[i, :] = -1
                    links[:, j] = -1
                elif link_type == 'one-to-many':
                    links[:, j] = -1
                elif link_type == 'many-to-one':
                    links[i, :] = -1
                else:
                    raise ValueError(f'not supported link type {link_type}')

        pairs_conf = [edges[i, j].item() for i, j in pred_pairs]
        return pred_pairs, pairs_conf
    def decode_pred(self, result):
        """Decode prediction.

        Assemble boxes and predicted labels into bboxes, and convert edges
        into a matrix.
        """
        filename = result['img_metas'][0]['ori_filename']
        nodes = result['nodes'].cpu()
        labels_conf, labels = torch.max(nodes, dim=-1)
        num_nodes = nodes.size(0)
        edges = result['edges'][:, -1].view(num_nodes, num_nodes).cpu()
        annos = self.data_dict[filename]['annotations']
        boxes = [x['box'] for x in annos]
        texts = [x['text'] for x in annos]
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs, pairs_conf = self._decode_pairs_pred(nodes, labels, edges,
                                                    self.edge_thr,
                                                    self.link_type)
        pred = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': labels_conf.tolist(),
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': pairs_conf
        }
        return pred

    def decode_gt(self, filename):
        """Decode ground truth.

        Assemble boxes and labels into bboxes.
        """
        annos = self.data_dict[filename]['annotations']
        labels = torch.Tensor([x['label'] for x in annos])
        texts = [x['text'] for x in annos]
        edge_ids = [x['edge'] for x in annos]
        boxes = [x['box'] for x in annos]
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs = self._decode_pairs_gt(labels, edge_ids)
        gt = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': [1. for _ in labels],
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': [1. for _ in pairs]
        }
        return gt
    def compute_openset_f1(self, preds, gts):
        """Compute openset macro-f1 and micro-f1 score.

        Args:
            preds (list[dict]): List of prediction results, including
                keys: ``filename``, ``pairs``, etc.
            gts (list[dict]): List of ground-truth infos, including
                keys: ``filename``, ``pairs``, etc.

        Returns:
            dict: Evaluation result with keys: ``node_openset_micro_f1``, \
                ``node_openset_macro_f1``, ``edge_openset_f1``.
        """
        total_edge_hit_num, total_edge_gt_num, total_edge_pred_num = 0, 0, 0
        total_node_hit_num, total_node_gt_num, total_node_pred_num = {}, {}, {}
        node_inds = list(range(self.node_classes))
        for node_idx in node_inds:
            total_node_hit_num[node_idx] = 0
            total_node_gt_num[node_idx] = 0
            total_node_pred_num[node_idx] = 0

        img_level_res = {}
        for pred, gt in zip(preds, gts):
            filename = pred['filename']
            img_res = {}
            # edge metric related
            pairs_pred = pred['pairs']
            pairs_gt = gt['pairs']
            img_res['edge_hit_num'] = 0
            for pair in pairs_gt:
                if pair in pairs_pred:
                    img_res['edge_hit_num'] += 1
            img_res['edge_recall'] = 1.0 * img_res['edge_hit_num'] / max(
                1, len(pairs_gt))
            img_res['edge_precision'] = 1.0 * img_res['edge_hit_num'] / max(
                1, len(pairs_pred))
            img_res['f1'] = 2 * img_res['edge_recall'] * img_res[
                'edge_precision'] / max(
                    1, img_res['edge_recall'] + img_res['edge_precision'])
            total_edge_hit_num += img_res['edge_hit_num']
            total_edge_gt_num += len(pairs_gt)
            total_edge_pred_num += len(pairs_pred)

            # node metric related
            nodes_pred = pred['labels']
            nodes_gt = gt['labels']
            for i, node_gt in enumerate(nodes_gt):
                node_gt = int(node_gt)
                total_node_gt_num[node_gt] += 1
                if nodes_pred[i] == node_gt:
                    total_node_hit_num[node_gt] += 1
            for node_pred in nodes_pred:
                total_node_pred_num[node_pred] += 1

            img_level_res[filename] = img_res

        # edge f1
        total_edge_recall = 1.0 * total_edge_hit_num / max(
            1, total_edge_gt_num)
        total_edge_precision = 1.0 * total_edge_hit_num / max(
            1, total_edge_pred_num)
        edge_f1 = 2 * total_edge_recall * total_edge_precision / max(
            1, total_edge_recall + total_edge_precision)
        stats = {'edge_openset_f1': edge_f1}

        # node f1
        cared_node_hit_num, cared_node_gt_num, cared_node_pred_num = 0, 0, 0
        node_macro_metric = {}
        for node_idx in node_inds:
            if node_idx < 1 or node_idx > 2:
                continue
            cared_node_hit_num += total_node_hit_num[node_idx]
            cared_node_gt_num += total_node_gt_num[node_idx]
            cared_node_pred_num += total_node_pred_num[node_idx]
            node_res = {}
            node_res['recall'] = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_gt_num[node_idx])
            node_res['precision'] = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_pred_num[node_idx])
            node_res[
                'f1'] = 2 * node_res['recall'] * node_res['precision'] / max(
                    1, node_res['recall'] + node_res['precision'])
            node_macro_metric[node_idx] = node_res

        node_micro_recall = 1.0 * cared_node_hit_num / max(
            1, cared_node_gt_num)
        node_micro_precision = 1.0 * cared_node_hit_num / max(
            1, cared_node_pred_num)
        node_micro_f1 = 2 * node_micro_recall * node_micro_precision / max(
            1, node_micro_recall + node_micro_precision)
        stats['node_openset_micro_f1'] = node_micro_f1
        stats['node_openset_macro_f1'] = np.mean(
            [v['f1'] for v in node_macro_metric.values()])
        return stats
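
To see how `link_type` changes `_decode_pairs_pred`, here is a small illustration with made-up scores (one predicted key, two candidate values):

```python
import torch

from mmocr.datasets.openset_kie_dataset import OpensetKIEDataset

# Made-up predictions: node 0 is a key, nodes 1 and 2 are values, and the
# key's edge scores to both values clear the 0.5 threshold.
nodes = torch.tensor([[0.0, 0.9, 0.1, 0.0],
                      [0.0, 0.1, 0.9, 0.0],
                      [0.0, 0.1, 0.9, 0.0]])
labels = torch.tensor([1, 2, 2])
edges = torch.tensor([[0.0, 0.9, 0.8],
                      [0.0, 0.0, 0.0],
                      [0.0, 0.0, 0.0]])

pairs, _ = OpensetKIEDataset._decode_pairs_pred(
    nodes, labels, edges, 0.5, 'one-to-one')
print(pairs)  # only the strongest pair is kept: (0, 1)

pairs, _ = OpensetKIEDataset._decode_pairs_pred(
    nodes, labels, edges, 0.5, 'one-to-many')
print(pairs)  # the key links to both values: (0, 1) and (0, 2)
```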

@@ -6,7 +6,7 @@ from mmdet.core import bbox2roi
from torch import nn
from torch.nn import functional as F

-from mmocr.core import imshow_edge_node
+from mmocr.core import imshow_edge, imshow_edge_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.utils import list_from_file
@@ -36,7 +36,8 @@ class SDMGR(SingleStageDetector):
                 train_cfg=None,
                 test_cfg=None,
                 class_list=None,
-                 init_cfg=None):
+                 init_cfg=None,
+                 openset=False):
        super().__init__(
            backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
        self.visual_modality = visual_modality
@@ -49,6 +50,7 @@
        else:
            self.extractor = None
        self.class_list = class_list
+        self.openset = openset

    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                      gt_labels):
@@ -136,15 +138,25 @@
        if out_file is not None:
            show = False

-        img = imshow_edge_node(
-            img,
-            result,
-            boxes,
-            idx_to_cls=idx_to_cls,
-            show=show,
-            win_name=win_name,
-            wait_time=wait_time,
-            out_file=out_file)
+        if self.openset:
+            img = imshow_edge(
+                img,
+                result,
+                boxes,
+                show=show,
+                win_name=win_name,
+                wait_time=wait_time,
+                out_file=out_file)
+        else:
+            img = imshow_edge_node(
+                img,
+                result,
+                boxes,
+                idx_to_cls=idx_to_cls,
+                show=show,
+                win_name=win_name,
+                wait_time=wait_time,
+                out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
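
With the new flag, a model built from the openset config above dispatches visualization to `imshow_edge`. A minimal sketch, assuming mmocr is installed, the working directory is the repository root, and `build_detector` is re-exported by `mmocr.models`:

```python
from mmcv import Config

from mmocr.models import build_detector

# Config path as listed in the README table earlier in this PR.
cfg = Config.fromfile(
    'configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_openset.py')
model = build_detector(cfg.model)
print(model.openset)  # True, so show_result() dispatches to imshow_edge
```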

@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
import os.path as osp
import tempfile

import torch

from mmocr.datasets.openset_kie_dataset import OpensetKIEDataset
from mmocr.utils import list_to_file


def _create_dummy_ann_file(ann_file):
    ann_info1 = {
        'file_name': '1.png',
        'height': 200,
        'width': 200,
        'annotations': [{
            'text': 'store',
            'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0],
            'label': 1,
            'edge': 1
        }, {
            'text': 'MyFamily',
            'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0],
            'label': 2,
            'edge': 1
        }]
    }
    list_to_file(ann_file, [json.dumps(ann_info1)])
    return ann_info1


def _create_dummy_dict_file(dict_file):
    dict_str = '0123'
    list_to_file(dict_file, list(dict_str))


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineJsonParser',
            keys=['file_name', 'height', 'width', 'annotations']))
    return loader
def test_openset_kie_dataset():
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        # create dummy data
        ann_file = osp.join(tmp_dir_name, 'fake_data.txt')
        ann_info1 = _create_dummy_ann_file(ann_file)

        dict_file = osp.join(tmp_dir_name, 'fake_dict.txt')
        _create_dummy_dict_file(dict_file)

        # test initialization
        loader = _create_dummy_loader()
        dataset = OpensetKIEDataset(ann_file, loader, dict_file, pipeline=[])
        dataset.prepare_train_img(0)

        # test pre_pipeline
        img_ann_info = dataset.data_infos[0]
        img_info = {
            'filename': img_ann_info['file_name'],
            'height': img_ann_info['height'],
            'width': img_ann_info['width']
        }
        ann_info = dataset._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)
        dataset.pre_pipeline(results)
        assert results['img_prefix'] == dataset.img_prefix
        assert 'ori_texts' in results

        # test evaluation
        result = {
            'img_metas': [{
                'filename': ann_info1['file_name'],
                'ori_filename': ann_info1['file_name'],
                'ori_texts': [],
                'ori_boxes': []
            }]
        }
        for anno in ann_info1['annotations']:
            result['img_metas'][0]['ori_texts'].append(anno['text'])
            result['img_metas'][0]['ori_boxes'].append(anno['box'])
        result['nodes'] = torch.tensor([[0.01, 0.8, 0.01, 0.18],
                                        [0.01, 0.01, 0.9, 0.08]])
        result['edges'] = torch.Tensor([[0.01, 0.99] for _ in range(4)])

        eval_res = dataset.evaluate([result])
        assert math.isclose(eval_res['edge_openset_f1'], 1.0, abs_tol=1e-4)
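
Since the dummy prediction matches the ground truth exactly, the node metrics come out perfect as well; two more assertions that would hold at the end of the test body:

```python
# Optional extension: both node-level openset metrics are also 1.0 here,
# because the dummy prediction matches the ground truth exactly.
assert math.isclose(eval_res['node_openset_micro_f1'], 1.0, abs_tol=1e-4)
assert math.isclose(eval_res['node_openset_macro_f1'], 1.0, abs_tol=1e-4)
```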

@@ -0,0 +1,121 @@
import argparse
import json
from functools import partial

import mmcv

from mmocr.utils import list_from_file, list_to_file


def convert(closeset_line, merge_bg_others=False, ignore_idx=0, others_idx=25):
    """Convert a line-json str of closeset to a line-json str of openset.

    Note that this function is designed for closeset-wildreceipt to
    openset-wildreceipt. It may not be suitable for your own dataset.

    Args:
        closeset_line (str): The string to be deserialized to
            the closeset dictionary object.
        merge_bg_others (bool): If True, give the same label to the
            "background" class and the "others" class.
        ignore_idx (int): Index for ``ignore`` class.
        others_idx (int): Index for ``others`` class.
    """
    # Two labels at the same index of the following two lists
    # make up a key-value pair. For example, in wildreceipt,
    # closeset_key_inds[0] maps to "Store_name_key"
    # and closeset_value_inds[0] maps to "Store_name_value".
    closeset_key_inds = list(range(2, others_idx, 2))
    closeset_value_inds = list(range(1, others_idx, 2))

    openset_node_label_mapping = {'bg': 0, 'key': 1, 'value': 2, 'others': 3}
    if merge_bg_others:
        openset_node_label_mapping['others'] = openset_node_label_mapping['bg']

    closeset_obj = json.loads(closeset_line)
    openset_obj = {
        'file_name': closeset_obj['file_name'],
        'height': closeset_obj['height'],
        'width': closeset_obj['width'],
        'annotations': []
    }

    edge_idx = 1
    label_to_edge = {}
    for anno in closeset_obj['annotations']:
        label = anno['label']
        if label == ignore_idx:
            anno['label'] = openset_node_label_mapping['bg']
            anno['edge'] = edge_idx
            edge_idx += 1
        elif label == others_idx:
            anno['label'] = openset_node_label_mapping['others']
            anno['edge'] = edge_idx
            edge_idx += 1
        else:
            edge = label_to_edge.get(label, None)
            if edge is not None:
                anno['edge'] = edge
                if label in closeset_key_inds:
                    anno['label'] = openset_node_label_mapping['key']
                elif label in closeset_value_inds:
                    anno['label'] = openset_node_label_mapping['value']
            else:
                tmp_key = 'key'
                if label in closeset_key_inds:
                    label_with_same_edge = closeset_value_inds[
                        closeset_key_inds.index(label)]
                elif label in closeset_value_inds:
                    label_with_same_edge = closeset_key_inds[
                        closeset_value_inds.index(label)]
                    tmp_key = 'value'
                edge_counterpart = label_to_edge.get(label_with_same_edge,
                                                     None)
                if edge_counterpart is not None:
                    anno['edge'] = edge_counterpart
                else:
                    anno['edge'] = edge_idx
                    edge_idx += 1
                anno['label'] = openset_node_label_mapping[tmp_key]
                label_to_edge[label] = anno['edge']

    openset_obj['annotations'] = closeset_obj['annotations']

    return json.dumps(openset_obj, ensure_ascii=False)
def process(closeset_file, openset_file, merge_bg_others=False, n_proc=10):
    closeset_lines = list_from_file(closeset_file)

    convert_func = partial(convert, merge_bg_others=merge_bg_others)

    openset_lines = mmcv.track_parallel_progress(
        convert_func, closeset_lines, nproc=n_proc)

    list_to_file(openset_file, openset_lines)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('in_file', help='Annotation file for closeset.')
    parser.add_argument('out_file', help='Annotation file for openset.')
    parser.add_argument(
        '--merge',
        action='store_true',
        help='Merge two classes: "background" and "others" in closeset '
        'to one class in openset.')
    parser.add_argument(
        '--n_proc', type=int, default=10, help='Number of processes.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    process(args.in_file, args.out_file, args.merge, args.n_proc)
    print('finish')


if __name__ == '__main__':
    main()
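
A quick way to check `convert()` is to feed it a hand-written closeset line; in wildreceipt, labels 2 and 1 form the `Store_name_key`/`Store_name_value` pair, so both boxes should come out sharing one `edge` id. A sketch with made-up box data:

```python
import json

closeset_line = json.dumps({
    'file_name': '1.png',
    'height': 200,
    'width': 200,
    'annotations': [
        {'text': 'Store', 'label': 2,
         'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]},
        {'text': 'MyFamily', 'label': 1,
         'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]},
    ]
})
print(convert(closeset_line))
# -> both annotations get "edge": 1; labels become 1 (key) and 2 (value)
```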