mirror of https://github.com/open-mmlab/mmocr.git

fix #21: refactor kie dataset & add show_results

parent 9d62bdf84c
commit 9af0bed144
@@ -1,5 +1,3 @@
-dataset_type = 'KIEDataset'
-data_root = 'data/wildreceipt'
 img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 max_scale, min_scale = 1024, 512
@@ -27,33 +25,35 @@ test_pipeline = [
     dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
 ]
 
-vocab_file = 'dict.txt'
-class_file = 'class_list.txt'
+dataset_type = 'KIEDataset'
+data_root = 'data/wildreceipt/'
+
+loader = dict(
+    type='HardDiskLoader',
+    repeat=1,
+    parser=dict(
+        type='LineJsonParser',
+        keys=['file_name', 'height', 'width', 'annotations']))
+
+train = dict(
+    type=dataset_type,
+    ann_file=data_root + 'train.txt',
+    pipeline=train_pipeline,
+    img_prefix=data_root,
+    loader=loader,
+    dict_file=data_root + 'dict.txt',
+    test_mode=False)
+test = dict(
+    type=dataset_type,
+    ann_file=data_root + 'test.txt',
+    pipeline=test_pipeline,
+    img_prefix=data_root,
+    loader=loader,
+    dict_file=data_root + 'dict.txt',
+    test_mode=True)
 
 data = dict(
-    samples_per_gpu=4,
-    workers_per_gpu=0,
-    train=dict(
-        type=dataset_type,
-        ann_file='train.txt',
-        pipeline=train_pipeline,
-        data_root=data_root,
-        vocab_file=vocab_file,
-        class_file=class_file),
-    val=dict(
-        type=dataset_type,
-        ann_file='test.txt',
-        pipeline=test_pipeline,
-        data_root=data_root,
-        vocab_file=vocab_file,
-        class_file=class_file),
-    test=dict(
-        type=dataset_type,
-        ann_file='test.txt',
-        pipeline=test_pipeline,
-        data_root=data_root,
-        vocab_file=vocab_file,
-        class_file=class_file))
+    samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test)
 
 evaluation = dict(
     interval=1,
@@ -69,7 +69,8 @@ model = dict(
         type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
     visual_modality=False,
     train_cfg=None,
-    test_cfg=None)
+    test_cfg=None,
+    class_list=data_root + 'class_list.txt')
 
 optimizer = dict(type='Adam', weight_decay=0.0001)
 optimizer_config = dict(grad_clip=None)
@@ -82,16 +83,7 @@ lr_config = dict(
 total_epochs = 60
 
 checkpoint_config = dict(interval=1)
-log_config = dict(
-    interval=50,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        # dict(
-        #     type='PaviLoggerHook',
-        #     add_last_ckpt=True,
-        #     interval=5,
-        #     init_kwargs=dict(project='kie')),
-    ])
+log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
 dist_params = dict(backend='nccl')
 log_level = 'INFO'
 load_from = None
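For reference, the LineJsonParser configured above reads one JSON object per line from train.txt/test.txt. A minimal illustrative line (the file name and values here are invented, not taken from the wildreceipt release) would be:

{"file_name": "images/receipt_0.jpeg", "height": 1200, "width": 720, "annotations": [{"box": [84, 109, 306, 109, 306, 140, 84, 140], "text": "SUBTOTAL", "label": 18, "edge": 0}]}

Here "box" holds the four corner points of a text region, "text" its transcription, "label" the node class index from class_list.txt, and "edge" an optional key-value link id; in the refactored dataset below, "box", "text" and "label" are required and "edge" defaults to 0.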
@@ -1,5 +1,3 @@
-dataset_type = 'KIEDataset'
-data_root = 'data/wildreceipt'
 img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 max_scale, min_scale = 1024, 512
@@ -27,33 +25,35 @@ test_pipeline = [
     dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
 ]
 
-vocab_file = 'dict.txt'
-class_file = 'class_list.txt'
+dataset_type = 'KIEDataset'
+data_root = 'data/wildreceipt/'
+
+loader = dict(
+    type='HardDiskLoader',
+    repeat=1,
+    parser=dict(
+        type='LineJsonParser',
+        keys=['file_name', 'height', 'width', 'annotations']))
+
+train = dict(
+    type=dataset_type,
+    ann_file=data_root + 'train.txt',
+    pipeline=train_pipeline,
+    img_prefix=data_root,
+    loader=loader,
+    dict_file=data_root + 'dict.txt',
+    test_mode=False)
+test = dict(
+    type=dataset_type,
+    ann_file=data_root + 'test.txt',
+    pipeline=test_pipeline,
+    img_prefix=data_root,
+    loader=loader,
+    dict_file=data_root + 'dict.txt',
+    test_mode=True)
 
 data = dict(
-    samples_per_gpu=4,
-    workers_per_gpu=0,
-    train=dict(
-        type=dataset_type,
-        ann_file='train.txt',
-        pipeline=train_pipeline,
-        data_root=data_root,
-        vocab_file=vocab_file,
-        class_file=class_file),
-    val=dict(
-        type=dataset_type,
-        ann_file='test.txt',
-        pipeline=test_pipeline,
-        data_root=data_root,
-        vocab_file=vocab_file,
-        class_file=class_file),
-    test=dict(
-        type=dataset_type,
-        ann_file='test.txt',
-        pipeline=test_pipeline,
-        data_root=data_root,
-        vocab_file=vocab_file,
-        class_file=class_file))
+    samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test)
 
 evaluation = dict(
     interval=1,
@@ -69,7 +69,8 @@ model = dict(
         type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
     visual_modality=True,
    train_cfg=None,
-    test_cfg=None)
+    test_cfg=None,
+    class_list=data_root + 'class_list.txt')
 
 optimizer = dict(type='Adam', weight_decay=0.0001)
 optimizer_config = dict(grad_clip=None)
@@ -82,16 +83,7 @@ lr_config = dict(
 total_epochs = 60
 
 checkpoint_config = dict(interval=1)
-log_config = dict(
-    interval=50,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        # dict(
-        #     type='PaviLoggerHook',
-        #     add_last_ckpt=True,
-        #     interval=5,
-        #     init_kwargs=dict(project='kie')),
-    ])
+log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
 dist_params = dict(backend='nccl')
 log_level = 'INFO'
 load_from = None
@@ -4,6 +4,7 @@ import warnings
+import cv2
 import mmcv
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
 
 import mmocr.utils as utils
@@ -367,3 +368,52 @@ def imshow_text_label(img,
         mmcv.imwrite(img, out_file)
 
     return img
+
+
+def imshow_edge_node(img,
+                     result,
+                     boxes,
+                     idx_to_cls={},
+                     show=False,
+                     win_name='',
+                     wait_time=-1,
+                     out_file=None):
+
+    img = mmcv.imread(img)
+    h, w = img.shape[:2]
+
+    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
+    max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
+    node_pred_label = max_idx.numpy().tolist()
+    node_pred_score = max_value.numpy().tolist()
+
+    for i, box in enumerate(boxes):
+        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+                   [box[0], box[3]]]
+        Pts = np.array([new_box], np.int32)
+        cv2.polylines(
+            img, [Pts.reshape((-1, 1, 2))],
+            True,
+            color=(255, 255, 0),
+            thickness=1)
+        x_min = int(min([point[0] for point in new_box]))
+        y_min = int(min([point[1] for point in new_box]))
+
+        pred_label = str(node_pred_label[i])
+        if pred_label in idx_to_cls:
+            pred_label = idx_to_cls[pred_label]
+        pred_score = '{:.2f}'.format(node_pred_score[i])
+        text = pred_label + '(' + pred_score + ')'
+        cv2.putText(pred_img, text, (x_min * 2, y_min),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+
+    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
+    vis_img[:, :w] = img
+    vis_img[:, w:] = pred_img
+
+    if show:
+        mmcv.imshow(vis_img, win_name, wait_time)
+    if out_file is not None:
+        mmcv.imwrite(vis_img, out_file)
+
+    return vis_img
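A minimal usage sketch for the new helper; the image path, node scores, and class names below are illustrative assumptions, not part of the commit:

import torch
from mmocr.core import imshow_edge_node

# Fake per-class node scores for two text boxes over three classes,
# mimicking what the model stores under result['nodes'].
result = dict(nodes=torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]]))
boxes = [[10, 10, 80, 30], [10, 40, 120, 60]]  # [x_min, y_min, x_max, y_max]
vis = imshow_edge_node(
    'demo.jpg',  # any image readable by mmcv.imread
    result,
    boxes,
    idx_to_cls={'0': 'other', '1': 'store_name'},  # string keys, as in show_result
    out_file='demo_kie.jpg')  # original image and predicted labels side by side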
@@ -1,261 +1,137 @@
 import copy
 from os import path as osp
 
-import mmcv
 import numpy as np
 import torch
-from matplotlib import pyplot as plt
-from PIL import Image
 
+import mmocr.utils as utils
 from mmdet.datasets.builder import DATASETS
-from mmdet.datasets.custom import CustomDataset
 from mmocr.core import compute_f1_score
+from mmocr.datasets.base_dataset import BaseDataset
+from mmocr.datasets.pipelines.crop import sort_vertex
 
 
 @DATASETS.register_module()
-class KIEDataset(CustomDataset):
+class KIEDataset(BaseDataset):
+    """
+    Args:
+        ann_file (str): Annotation file path.
+        pipeline (list[dict]): Processing pipeline.
+        loader (dict): Dictionary to construct loader
+            to load annotation infos.
+        img_prefix (str, optional): Image prefix to generate full
+            image path.
+        test_mode (bool, optional): If True, try...except will
+            be turned off in __getitem__.
+        dict_file (str): Character dict file path.
+        norm (float): Norm to map value from one range to another.
+    """
 
     def __init__(self,
                  ann_file,
-                 pipeline=None,
-                 data_root=None,
+                 loader,
+                 dict_file,
                  img_prefix='',
-                 ann_prefix='',
-                 vocab_file=None,
-                 class_file=None,
+                 pipeline=None,
                  norm=10.,
-                 thresholds=dict(edge=0.5),
                  directed=False,
                  test_mode=True,
                  **kwargs):
-        self.ann_prefix = ann_prefix
-        self.norm = norm
-        self.thresholds = thresholds
-        self.directed = directed
-
-        if data_root is not None:
-            if not osp.isabs(ann_file):
-                self.ann_file = osp.join(data_root, ann_file)
-            if not (ann_prefix is None or osp.isabs(ann_prefix)):
-                self.ann_prefix = osp.join(data_root, ann_prefix)
-
-            self.vocab = dict({'': 0})
-            vocab_file = osp.join(data_root, vocab_file)
-            if osp.exists(vocab_file):
-                with open(vocab_file, 'r') as fid:
-                    for idx, char in enumerate(fid.readlines(), 1):
-                        self.vocab[char.strip('\n')] = idx
-            else:
-                self.construct_dict(self.ann_file)
-                with open(vocab_file, 'w') as fid:
-                    for key in self.vocab:
-                        if key:
-                            fid.write('{}\n'.format(key))
-
         super().__init__(
             ann_file,
+            loader,
             pipeline,
-            data_root=data_root,
             img_prefix=img_prefix,
-            **kwargs)
+            test_mode=test_mode)
+        assert osp.exists(dict_file)
 
-        self.idx_to_cls = dict()
-        with open(osp.join(data_root, class_file), 'r') as fid:
-            for line in fid.readlines():
-                idx, cls = line.split()
-                self.idx_to_cls[int(idx)] = cls
+        self.norm = norm
+        self.directed = directed
 
-    @staticmethod
-    def _split_edge(line):
-        text = ','.join(line[8:-1])
-        if ';' in text and text.split(';')[0].isdecimal():
-            edge, text = text.split(';', 1)
-            edge = int(edge)
-        else:
-            edge = 0
-        return edge, text
+        self.dict = dict({'': 0})
+        with open(dict_file, 'r') as fr:
+            idx = 1
+            for line in fr:
+                char = line.strip()
+                self.dict[char] = idx
+                idx += 1
 
-    def construct_dict(self, ann_file):
-        img_infos = mmcv.list_from_file(ann_file)
-        for img_info in img_infos:
-            _, annname = img_info.split()
-            if self.ann_prefix:
-                annname = osp.join(self.ann_prefix, annname)
-            with open(annname, 'r') as fid:
-                lines = fid.readlines()
+    def pre_pipeline(self, results):
+        results['img_prefix'] = self.img_prefix
+        results['bbox_fields'] = []
 
-            for line in lines:
-                line = line.strip().split(',')
-                _, text = self._split_edge(line)
-                for c in text:
-                    if c not in self.vocab:
-                        self.vocab[c] = len(self.vocab)
-        self.vocab = dict(
-            {k: idx
-             for idx, k in enumerate(sorted(self.vocab.keys()))})
+    def _parse_anno_info(self, annotations):
+        """Parse annotations of boxes, texts and labels for one image.
+        Args:
+            annotations (list[dict]): Annotations of one image, where
+                each dict is for one character.
 
-    def convert_text(self, text):
-        return [self.vocab[c] for c in text if c in self.vocab]
+        Returns:
+            dict: A dict containing the following keys:
 
-    def parse_lines(self, annname):
-        boxes, edges, texts, chars, labels = [], [], [], [], []
+                - bboxes (np.ndarray): Bbox in one image with shape:
+                    box_num * 4.
+                - relations (np.ndarray): Relations between bbox with shape:
+                    box_num * box_num * D.
+                - texts (np.ndarray): Text index with shape:
+                    box_num * text_max_len.
+                - labels (np.ndarray): Box Labels with shape:
+                    box_num * (box_num + 1).
+        """
 
-        if self.ann_prefix:
-            annname = osp.join(self.ann_prefix, annname)
+        assert utils.is_type_list(annotations, dict)
+        assert 'box' in annotations[0]
+        assert 'text' in annotations[0]
+        assert 'label' in annotations[0]
 
-        with open(annname, 'r') as fid:
-            for line in fid.readlines():
-                line = line.strip().split(',')
-                boxes.append([int(x) for x in list(map(float, line[:8]))])
-                edge, text = self._split_edge(line)
-                chars.append(text)
-                text = self.convert_text(text)
-                texts.append(text)
-                edges.append(edge)
-                labels.append(int(line[-1]))
-        return dict(
-            boxes=boxes, edges=edges, texts=texts, chars=chars, labels=labels)
+        boxes, texts, text_inds, labels, edges = [], [], [], [], []
+        for ann in annotations:
+            box = ann['box']
+            x_list, y_list = box[0:8:2], box[1:9:2]
+            sorted_x_list, sorted_y_list = sort_vertex(x_list, y_list)
+            sorted_box = []
+            for x, y in zip(sorted_x_list, sorted_y_list):
+                sorted_box.append(x)
+                sorted_box.append(y)
+            boxes.append(sorted_box)
+            text = ann['text']
+            texts.append(ann['text'])
+            text_ind = [self.dict[c] for c in text if c in self.dict]
+            text_inds.append(text_ind)
+            labels.append(ann['label'])
+            edges.append(ann.get('edge', 0))
 
-    def format_results(self, results):
-        boxes = torch.Tensor(results['boxes'])[:, [0, 1, 4, 5]].cuda()
-
-        if 'nodes' in results:
-            nodes, edges = results['nodes'], results['edges']
-            labels = nodes.argmax(-1)
-            num_nodes = nodes.size(0)
-            edges = edges[:, -1].view(num_nodes, num_nodes)
-        else:
-            labels = torch.Tensor(results['labels']).cuda()
-            edges = torch.Tensor(results['edges']).cuda()
-        boxes = torch.cat([boxes, labels[:, None].float()], -1)
-
-        return {
-            **{
-                k: v
-                for k, v in results.items() if k not in ['boxes', 'edges']
-            }, 'boxes': boxes,
-            'edges': edges,
-            'points': results['boxes']
-        }
-
-    def plot(self, results):
-        img_name = osp.join(self.img_prefix, results['filename'])
-        img = plt.imread(img_name)
-        plt.imshow(img)
-
-        boxes, texts = results['points'], results['chars']
-        num_nodes = len(boxes)
-        if 'scores' in results:
-            scores = results['scores']
-        else:
-            scores = np.ones(num_nodes)
-        for box, text, score in zip(boxes, texts, scores):
-            xs, ys = [], []
-            for idx in range(0, 10, 2):
-                xs.append(box[idx % 8])
-                ys.append(box[(idx + 1) % 8])
-            plt.plot(xs, ys, 'g')
-            plt.annotate(
-                '{}: {:.4f}'.format(text, score), (box[0], box[1]), color='g')
-
-        if 'nodes' in results:
-            nodes = results['nodes']
-            inds = nodes.argmax(-1)
-        else:
-            nodes = np.ones((num_nodes, 3))
-            inds = results['labels']
-        for i in range(num_nodes):
-            plt.annotate(
-                '{}: {:.4f}'.format(
-                    self.idx_to_cls(inds[i] - 1), nodes[i, inds[i]]),
-                (boxes[i][6], boxes[i][7]),
-                color='r' if inds[i] == 1 else 'b')
-            edges = results['edges']
-            if 'nodes' not in results:
-                edges = (edges[:, None] == edges[None]).float()
-            for j in range(i + 1, num_nodes):
-                edge_score = max(edges[i][j], edges[j][i])
-                if edge_score > self.thresholds['edge']:
-                    x1 = sum(boxes[i][:3:2]) // 2
-                    y1 = sum(boxes[i][3:6:2]) // 2
-                    x2 = sum(boxes[j][:3:2]) // 2
-                    y2 = sum(boxes[j][3:6:2]) // 2
-                    plt.plot((x1, x2), (y1, y2), 'r')
-                    plt.annotate(
-                        '{:.4f}'.format(edge_score),
-                        ((x1 + x2) // 2, (y1 + y2) // 2),
-                        color='r')
-
-    def compute_relation(self, boxes):
-        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
-        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
-        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
-        dxs = (x1s[:, 0][None] - x1s) / self.norm
-        dys = (y1s[:, 0][None] - y1s) / self.norm
-        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
-        whs = ws / hs + np.zeros_like(xhhs)
-        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
-        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
-        return relations, bboxes
-
-    def ann_numpy(self, results):
-        boxes, texts = results['boxes'], results['texts']
-        boxes = np.array(boxes, np.int32)
-        if boxes[0, 1] > boxes[0, -1]:
-            boxes = boxes[:, [6, 7, 4, 5, 2, 3, 0, 1]]
-        relations, bboxes = self.compute_relation(boxes)
-
-        labels = results.get('labels', None)
-        if labels is not None:
-            labels = np.array(labels, np.int32)
-            edges = results.get('edges', None)
-            if edges is not None:
-                labels = labels[:, None]
-                edges = np.array(edges)
-                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
-                if self.directed:
-                    edges = (edges & labels == 1).astype(np.int32)
-                np.fill_diagonal(edges, -1)
-                labels = np.concatenate([labels, edges], -1)
-        return dict(
-            bboxes=bboxes,
-            relations=relations,
-            texts=self.pad_text(texts),
+        ann_infos = dict(
+            boxes=boxes,
+            texts=texts,
+            text_inds=text_inds,
+            edges=edges,
             labels=labels)
 
-    def image_size(self, filename):
-        img_path = osp.join(self.img_prefix, filename)
-        img = Image.open(img_path)
-        return img.size
+        return self.list_to_numpy(ann_infos)
 
-    def load_annotations(self, ann_file):
-        self.anns, data_infos = [], []
+    def prepare_train_img(self, index):
+        """Get training data and annotations from pipeline.
 
-        self.gts = dict()
-        img_infos = mmcv.list_from_file(ann_file)
-        for img_info in img_infos:
-            filename, annname = img_info.split()
-            results = self.parse_lines(annname)
-            width, height = self.image_size(filename)
+        Args:
+            index (int): Index of data.
 
-            data_infos.append(
-                dict(filename=filename, width=width, height=height))
-            ann = self.ann_numpy(results)
-            self.anns.append(ann)
+        Returns:
+            dict: Training data and annotation after pipeline with new keys \
+                introduced by pipeline.
+        """
+        img_ann_info = self.data_infos[index]
+        img_info = {
+            'filename': img_ann_info['file_name'],
+            'height': img_ann_info['height'],
+            'width': img_ann_info['width']
+        }
+        ann_info = self._parse_anno_info(img_ann_info['annotations'])
+        results = dict(img_info=img_info, ann_info=ann_info)
 
-        return data_infos
+        self.pre_pipeline(results)
 
-    def pad_text(self, texts):
-        max_len = max([len(text) for text in texts])
-        padded_texts = -np.ones((len(texts), max_len), np.int32)
-        for idx, text in enumerate(texts):
-            padded_texts[idx, :len(text)] = np.array(text)
-        return padded_texts
-
-    def get_ann_info(self, idx):
-        return self.anns[idx]
-
-    def prepare_test_img(self, idx):
-        return self.prepare_train_img(idx)
+        return self.pipeline(results)
 
     def evaluate(self,
                  results,
@@ -278,18 +154,65 @@ class KIEDataset(CustomDataset):
 
     def compute_macro_f1(self, results, ignores=[]):
         node_preds = []
-        for result in results:
+        node_gts = []
+        for idx, result in enumerate(results):
             node_preds.append(result['nodes'])
-        node_preds = torch.cat(node_preds)
+            box_ann_infos = self.data_infos[idx]['annotations']
+            node_gt = [box_ann_info['label'] for box_ann_info in box_ann_infos]
+            node_gts.append(torch.Tensor(node_gt))
 
-        node_gts = [
-            torch.from_numpy(ann['labels'][:, 0]).to(node_preds.device)
-            for ann in self.anns
-        ]
-        node_gts = torch.cat(node_gts)
+        node_preds = torch.cat(node_preds)
+        node_gts = torch.cat(node_gts).int().to(node_preds.device)
 
         node_f1s = compute_f1_score(node_preds, node_gts, ignores)
 
         return {
             'macro_f1': node_f1s.mean(),
        }
+
+    def list_to_numpy(self, ann_infos):
+        """Convert list to np.ndarray."""
+        boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds']
+        boxes = np.array(boxes, np.int32)
+        relations, bboxes = self.compute_relation(boxes)
+
+        labels = ann_infos.get('labels', None)
+        if labels is not None:
+            labels = np.array(labels, np.int32)
+            edges = ann_infos.get('edges', None)
+            if edges is not None:
+                labels = labels[:, None]
+                edges = np.array(edges)
+                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+                if self.directed:
+                    edges = (edges & labels == 1).astype(np.int32)
+                np.fill_diagonal(edges, -1)
+                labels = np.concatenate([labels, edges], -1)
+        padded_text_inds = self.pad_text_ind(text_inds)
+
+        return dict(
+            bboxes=bboxes,
+            relations=relations,
+            texts=padded_text_inds,
+            labels=labels)
+
+    def pad_text_ind(self, text_inds):
+        """Pad text index to same length."""
+        max_len = max([len(text_ind) for text_ind in text_inds])
+        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
+        for idx, text_ind in enumerate(text_inds):
+            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+        return padded_text_inds
+
+    def compute_relation(self, boxes):
+        """Compute relation between every two boxes."""
+        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+        dxs = (x1s[:, 0][None] - x1s) / self.norm
+        dys = (y1s[:, 0][None] - y1s) / self.norm
+        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+        whs = ws / hs + np.zeros_like(xhhs)
+        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+        return relations, bboxes
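As a reading aid (not part of the commit itself): for a pair of boxes (i, j), compute_relation stacks five pairwise features, [(x1_j - x1_i)/norm, (y1_j - y1_i)/norm, w_i/h_i, h_j/h_i, w_j/h_i]. A self-contained numeric check of the first feature:

import numpy as np

norm = 10.
# Two axis-aligned boxes as 8-point polygons (x1, y1, ..., x4, y4).
boxes = np.array([
    [0, 0, 40, 0, 40, 20, 0, 20],
    [10, 30, 30, 30, 30, 40, 10, 40],
], np.int32)

x1s = boxes[:, 0:1]
dxs = (x1s[:, 0][None] - x1s) / norm
# dxs[i, j] == (x1_j - x1_i) / norm; box 1 starts 10 px right of box 0:
assert dxs[0, 1] == 1.0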
@@ -1,9 +1,13 @@
+import warnings
+
+import mmcv
 from torch import nn
 from torch.nn import functional as F
 
 from mmdet.core import bbox2roi
 from mmdet.models.builder import DETECTORS, build_roi_extractor
 from mmdet.models.detectors import SingleStageDetector
+from mmocr.core import imshow_edge_node
 
 
 @DETECTORS.register_module()
@@ -13,6 +17,9 @@ class SDMGR(SingleStageDetector):
 
     Args:
         visual_modality (bool): Whether to use the visual modality.
+        class_list (None | str): Mapping file of class index to
+            class name. If None, class index will be shown in
+            `show_results`, else class name.
     """
 
     def __init__(self,
@@ -26,7 +33,8 @@ class SDMGR(SingleStageDetector):
                  visual_modality=False,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 class_list=None):
         super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                          pretrained)
         self.visual_modality = visual_modality
@@ -38,6 +46,7 @@ class SDMGR(SingleStageDetector):
             self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
         else:
             self.extractor = None
+        self.class_list = class_list
 
     def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                       gt_labels):
@@ -85,3 +94,61 @@ class SDMGR(SingleStageDetector):
             feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
             return feats.view(feats.size(0), -1)
         return None
+
+    def show_result(self,
+                    img,
+                    result,
+                    boxes,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None,
+                    **kwargs):
+        """Draw `result` on `img`.
+
+        Args:
+            img (str or tensor): The image to be displayed.
+            result (dict): The results to draw on `img`.
+            boxes (list): Bbox of img.
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The output filename.
+                Default: None.
+
+        Returns:
+            img (tensor): Only if not `show` or `out_file`.
+        """
+        img = mmcv.imread(img)
+        img = img.copy()
+
+        idx_to_cls = {}
+        if self.class_list is not None:
+            with open(self.class_list, 'r') as fr:
+                for line in fr:
+                    line = line.strip().split()
+                    class_idx, class_label = line
+                    idx_to_cls[class_idx] = class_label
+
+        # if out_file specified, do not show image in window
+        if out_file is not None:
+            show = False
+
+        img = imshow_edge_node(
+            img,
+            result,
+            boxes,
+            idx_to_cls=idx_to_cls,
+            show=show,
+            win_name=win_name,
+            wait_time=wait_time,
+            out_file=out_file)
+
+        if not (show or out_file):
+            warnings.warn('show==False and out_file is not specified, only '
+                          'result image will be returned')
+            return img
+
+        return img
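The test script added below drives this method; condensed, the per-image call pattern looks like this (a sketch of the existing calls, with a hypothetical output path):

# result = model(return_loss=False, rescale=True, **data) per batch, then:
model.module.show_result(
    img_show,      # de-normalized image crop, HWC uint8
    result[i],     # dict holding the 'nodes' scores for each text box
    gt_bboxes[i],  # ground-truth boxes reused for drawing
    show=False,
    out_file='show_kie/receipt_0.jpeg')  # hypothetical path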
@@ -0,0 +1,107 @@
+import argparse
+import os
+import os.path as osp
+
+import mmcv
+import torch
+from mmcv import Config
+from mmcv.image import tensor2imgs
+from mmcv.parallel import MMDataParallel
+from mmcv.runner import load_checkpoint
+
+from mmocr.datasets import build_dataloader, build_dataset
+from mmocr.models import build_detector
+
+
+def test(model, data_loader, show=False, out_dir=None):
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+
+        batch_size = len(result)
+        if show or out_dir:
+            img_tensor = data['img'].data[0]
+            img_metas = data['img_metas'].data[0]
+            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+            assert len(imgs) == len(img_metas)
+            gt_bboxes = [data['gt_bboxes'].data[0][0].numpy().tolist()]
+
+            for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
+                h, w, _ = img_meta['img_shape']
+                img_show = img[:h, :w, :]
+
+                if out_dir:
+                    out_file = osp.join(out_dir, img_meta['ori_filename'])
+                else:
+                    out_file = None
+
+                model.module.show_result(
+                    img_show,
+                    result[i],
+                    gt_bboxes[i],
+                    show=show,
+                    out_file=out_file)
+
+        for _ in range(batch_size):
+            prog_bar.update()
+    return results
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='MMOCR visualize for kie model.')
+    parser.add_argument('config', help='Test config file path.')
+    parser.add_argument('checkpoint', help='Checkpoint file.')
+    parser.add_argument('--show', action='store_true', help='Show results.')
+    parser.add_argument(
+        '--show-dir', help='Directory where the output images will be saved.')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    return args
+
+
+def main():
+    args = parse_args()
+    assert args.show or args.show_dir, \
+        ('Please specify at least one operation (show the results'
+         ' / save the results) with the argument "--show" or "--show-dir".')
+
+    cfg = Config.fromfile(args.config)
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from mmcv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+    cfg.model.pretrained = None
+
+    distributed = False
+
+    # build the dataloader
+    dataset = build_dataset(cfg.data.test)
+    data_loader = build_dataloader(
+        dataset,
+        samples_per_gpu=1,
+        workers_per_gpu=cfg.data.workers_per_gpu,
+        dist=distributed,
+        shuffle=False)
+
+    # build the model and load checkpoint
+    cfg.model.train_cfg = None
+    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+    model = MMDataParallel(model, device_ids=[0])
+    test(model, data_loader, args.show, args.show_dir)
+
+
+if __name__ == '__main__':
+    main()
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+DATE=`date +%Y-%m-%d`
+TIME=`date +"%H-%M-%S"`
+
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash $0 CONFIG CHECKPOINT SHOW_DIR"
+    exit
+fi
+
+CONFIG_FILE=$1
+CHECKPOINT=$2
+SHOW_DIR=$3_${DATE}_${TIME}
+
+mkdir ${SHOW_DIR} -p &&
+
+python tools/kie_test_imgs.py \
+    ${CONFIG_FILE} \
+    ${CHECKPOINT} \
+    --show-dir ${SHOW_DIR}
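With the pieces above in place, a visualization run could be launched as follows (the config and checkpoint paths are placeholders, since the config file names are not shown in this diff):

bash tools/kie_test_imgs.sh ${KIE_CONFIG} ${KIE_CHECKPOINT} show_kie

The third argument is suffixed with the current date and time, so the images drawn by show_result land in show_kie_<date>_<time>/.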