fix #21: refactor kie dataset & add show_results

pull/2/head
Hongbin Sun 2021-04-05 10:34:37 +08:00
parent 9d62bdf84c
commit 9af0bed144
7 changed files with 456 additions and 304 deletions

View File

@@ -1,5 +1,3 @@
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
max_scale, min_scale = 1024, 512
@@ -27,33 +25,35 @@ test_pipeline = [
dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
]
vocab_file = 'dict.txt'
class_file = 'class_list.txt'
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt/'
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineJsonParser',
keys=['file_name', 'height', 'width', 'annotations']))
train = dict(
type=dataset_type,
ann_file=data_root + 'train.txt',
pipeline=train_pipeline,
img_prefix=data_root,
loader=loader,
dict_file=data_root + 'dict.txt',
test_mode=False)
test = dict(
type=dataset_type,
ann_file=data_root + 'test.txt',
pipeline=test_pipeline,
img_prefix=data_root,
loader=loader,
dict_file=data_root + 'dict.txt',
test_mode=True)
data = dict(
samples_per_gpu=4,
workers_per_gpu=0,
train=dict(
type=dataset_type,
ann_file='train.txt',
pipeline=train_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
val=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
test=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file))
samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test)
evaluation = dict(
interval=1,
@@ -69,7 +69,8 @@ model = dict(
type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
visual_modality=False,
train_cfg=None,
test_cfg=None)
test_cfg=None,
class_list=data_root + 'class_list.txt')
optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
@@ -82,16 +83,7 @@ lr_config = dict(
total_epochs = 60
checkpoint_config = dict(interval=1)
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(
# type='PaviLoggerHook',
# add_last_ckpt=True,
# interval=5,
# init_kwargs=dict(project='kie')),
])
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
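
For context, a minimal sketch of how the refactored dataset dicts above are consumed (not part of this commit; the config path is hypothetical):

from mmcv import Config
from mmocr.datasets import build_dataset

cfg = Config.fromfile('configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py')  # hypothetical path
# cfg.data.train resolves to the `train` dict above: KIEDataset reads ann_file
# through HardDiskLoader/LineJsonParser and maps characters via dict_file.
train_dataset = build_dataset(cfg.data.train)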

View File

@@ -1,5 +1,3 @@
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
max_scale, min_scale = 1024, 512
@@ -27,33 +25,35 @@ test_pipeline = [
dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
]
vocab_file = 'dict.txt'
class_file = 'class_list.txt'
dataset_type = 'KIEDataset'
data_root = 'data/wildreceipt/'
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(
type='LineJsonParser',
keys=['file_name', 'height', 'width', 'annotations']))
train = dict(
type=dataset_type,
ann_file=data_root + 'train.txt',
pipeline=train_pipeline,
img_prefix=data_root,
loader=loader,
dict_file=data_root + 'dict.txt',
test_mode=False)
test = dict(
type=dataset_type,
ann_file=data_root + 'test.txt',
pipeline=test_pipeline,
img_prefix=data_root,
loader=loader,
dict_file=data_root + 'dict.txt',
test_mode=True)
data = dict(
samples_per_gpu=4,
workers_per_gpu=0,
train=dict(
type=dataset_type,
ann_file='train.txt',
pipeline=train_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
val=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file),
test=dict(
type=dataset_type,
ann_file='test.txt',
pipeline=test_pipeline,
data_root=data_root,
vocab_file=vocab_file,
class_file=class_file))
samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test)
evaluation = dict(
interval=1,
@@ -69,7 +69,8 @@ model = dict(
type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
visual_modality=True,
train_cfg=None,
test_cfg=None)
test_cfg=None,
class_list=data_root + 'class_list.txt')
optimizer = dict(type='Adam', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
@@ -82,16 +83,7 @@ lr_config = dict(
total_epochs = 60
checkpoint_config = dict(interval=1)
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(
# type='PaviLoggerHook',
# add_last_ckpt=True,
# interval=5,
# init_kwargs=dict(project='kie')),
])
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
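
A similar sketch for building the model from this visual-modality config (not part of the diff; the path is hypothetical), showing where `class_list` ends up:

from mmcv import Config
from mmocr.models import build_detector

cfg = Config.fromfile('configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py')  # hypothetical path
cfg.model.pretrained = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
# model.class_list is 'data/wildreceipt/class_list.txt'; it is only read by
# show_result (added later in this diff) to map node indices to class names.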

View File

@@ -4,6 +4,7 @@ import warnings
import cv2
import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
import mmocr.utils as utils
@@ -367,3 +368,52 @@ def imshow_text_label(img,
mmcv.imwrite(img, out_file)
return img
def imshow_edge_node(img,
result,
boxes,
idx_to_cls={},
show=False,
win_name='',
wait_time=-1,
out_file=None):
img = mmcv.imread(img)
h, w = img.shape[:2]
pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
for i, box in enumerate(boxes):
new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
[box[0], box[3]]]
Pts = np.array([new_box], np.int32)
cv2.polylines(
img, [Pts.reshape((-1, 1, 2))],
True,
color=(255, 255, 0),
thickness=1)
x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box]))
pred_label = str(node_pred_label[i])
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i])
text = pred_label + '(' + pred_score + ')'
cv2.putText(pred_img, text, (x_min * 2, y_min),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img
vis_img[:, w:] = pred_img
if show:
mmcv.imshow(vis_img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(vis_img, out_file)
return vis_img
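
A hedged usage sketch of the new helper (the image path, scores and class names below are made up; `result['nodes']` is expected to be a torch tensor of per-box class logits):

import torch
result = dict(nodes=torch.rand(2, 26))         # 2 boxes, 26 node classes
boxes = [[10, 10, 80, 40], [10, 50, 120, 90]]  # [x_min, y_min, x_max, y_max]
vis = imshow_edge_node(
    'demo.jpg',                                # hypothetical image path
    result,
    boxes,
    idx_to_cls={'0': 'Ignore', '1': 'Store_name_value'},  # keys are stringified indices
    out_file='demo_kie.jpg')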

View File

@@ -1,261 +1,137 @@
import copy
from os import path as osp
import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
import mmocr.utils as utils
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines.crop import sort_vertex
@DATASETS.register_module()
class KIEDataset(CustomDataset):
class KIEDataset(BaseDataset):
"""
Args:
ann_file (str): Annotation file path.
pipeline (list[dict]): Processing pipeline.
loader (dict): Dictionary to construct loader
to load annotation infos.
img_prefix (str, optional): Image prefix to generate full
image path.
test_mode (bool, optional): If True, try...except will
be turned off in __getitem__.
dict_file (str): Character dict file path.
norm (float): Norm to map value from one range to another.
"""
def __init__(self,
ann_file,
pipeline=None,
data_root=None,
loader,
dict_file,
img_prefix='',
ann_prefix='',
vocab_file=None,
class_file=None,
pipeline=None,
norm=10.,
thresholds=dict(edge=0.5),
directed=False,
test_mode=True,
**kwargs):
self.ann_prefix = ann_prefix
self.norm = norm
self.thresholds = thresholds
self.directed = directed
if data_root is not None:
if not osp.isabs(ann_file):
self.ann_file = osp.join(data_root, ann_file)
if not (ann_prefix is None or osp.isabs(ann_prefix)):
self.ann_prefix = osp.join(data_root, ann_prefix)
self.vocab = dict({'': 0})
vocab_file = osp.join(data_root, vocab_file)
if osp.exists(vocab_file):
with open(vocab_file, 'r') as fid:
for idx, char in enumerate(fid.readlines(), 1):
self.vocab[char.strip('\n')] = idx
else:
self.construct_dict(self.ann_file)
with open(vocab_file, 'w') as fid:
for key in self.vocab:
if key:
fid.write('{}\n'.format(key))
super().__init__(
ann_file,
loader,
pipeline,
data_root=data_root,
img_prefix=img_prefix,
**kwargs)
test_mode=test_mode)
assert osp.exists(dict_file)
self.idx_to_cls = dict()
with open(osp.join(data_root, class_file), 'r') as fid:
for line in fid.readlines():
idx, cls = line.split()
self.idx_to_cls[int(idx)] = cls
self.norm = norm
self.directed = directed
@staticmethod
def _split_edge(line):
text = ','.join(line[8:-1])
if ';' in text and text.split(';')[0].isdecimal():
edge, text = text.split(';', 1)
edge = int(edge)
else:
edge = 0
return edge, text
self.dict = dict({'': 0})
with open(dict_file, 'r') as fr:
idx = 1
for line in fr:
char = line.strip()
self.dict[char] = idx
idx += 1
def construct_dict(self, ann_file):
img_infos = mmcv.list_from_file(ann_file)
for img_info in img_infos:
_, annname = img_info.split()
if self.ann_prefix:
annname = osp.join(self.ann_prefix, annname)
with open(annname, 'r') as fid:
lines = fid.readlines()
def pre_pipeline(self, results):
results['img_prefix'] = self.img_prefix
results['bbox_fields'] = []
for line in lines:
line = line.strip().split(',')
_, text = self._split_edge(line)
for c in text:
if c not in self.vocab:
self.vocab[c] = len(self.vocab)
self.vocab = dict(
{k: idx
for idx, k in enumerate(sorted(self.vocab.keys()))})
def _parse_anno_info(self, annotations):
"""Parse annotations of boxes, texts and labels for one image.
Args:
annotations (list[dict]): Annotations of one image, where
each dict is for one character.
def convert_text(self, text):
return [self.vocab[c] for c in text if c in self.vocab]
Returns:
dict: A dict containing the following keys:
def parse_lines(self, annname):
boxes, edges, texts, chars, labels = [], [], [], [], []
- bboxes (np.ndarray): Bbox in one image with shape:
box_num * 4.
- relations (np.ndarray): Relations between bbox with shape:
box_num * box_num * D.
- texts (np.ndarray): Text index with shape:
box_num * text_max_len.
- labels (np.ndarray): Box Labels with shape:
box_num * (box_num + 1).
"""
if self.ann_prefix:
annname = osp.join(self.ann_prefix, annname)
assert utils.is_type_list(annotations, dict)
assert 'box' in annotations[0]
assert 'text' in annotations[0]
assert 'label' in annotations[0]
with open(annname, 'r') as fid:
for line in fid.readlines():
line = line.strip().split(',')
boxes.append([int(x) for x in list(map(float, line[:8]))])
edge, text = self._split_edge(line)
chars.append(text)
text = self.convert_text(text)
texts.append(text)
edges.append(edge)
labels.append(int(line[-1]))
return dict(
boxes=boxes, edges=edges, texts=texts, chars=chars, labels=labels)
boxes, texts, text_inds, labels, edges = [], [], [], [], []
for ann in annotations:
box = ann['box']
x_list, y_list = box[0:8:2], box[1:9:2]
sorted_x_list, sorted_y_list = sort_vertex(x_list, y_list)
sorted_box = []
for x, y in zip(sorted_x_list, sorted_y_list):
sorted_box.append(x)
sorted_box.append(y)
boxes.append(sorted_box)
text = ann['text']
texts.append(ann['text'])
text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind)
labels.append(ann['label'])
edges.append(ann.get('edge', 0))
def format_results(self, results):
boxes = torch.Tensor(results['boxes'])[:, [0, 1, 4, 5]].cuda()
if 'nodes' in results:
nodes, edges = results['nodes'], results['edges']
labels = nodes.argmax(-1)
num_nodes = nodes.size(0)
edges = edges[:, -1].view(num_nodes, num_nodes)
else:
labels = torch.Tensor(results['labels']).cuda()
edges = torch.Tensor(results['edges']).cuda()
boxes = torch.cat([boxes, labels[:, None].float()], -1)
return {
**{
k: v
for k, v in results.items() if k not in ['boxes', 'edges']
}, 'boxes': boxes,
'edges': edges,
'points': results['boxes']
}
def plot(self, results):
img_name = osp.join(self.img_prefix, results['filename'])
img = plt.imread(img_name)
plt.imshow(img)
boxes, texts = results['points'], results['chars']
num_nodes = len(boxes)
if 'scores' in results:
scores = results['scores']
else:
scores = np.ones(num_nodes)
for box, text, score in zip(boxes, texts, scores):
xs, ys = [], []
for idx in range(0, 10, 2):
xs.append(box[idx % 8])
ys.append(box[(idx + 1) % 8])
plt.plot(xs, ys, 'g')
plt.annotate(
'{}: {:.4f}'.format(text, score), (box[0], box[1]), color='g')
if 'nodes' in results:
nodes = results['nodes']
inds = nodes.argmax(-1)
else:
nodes = np.ones((num_nodes, 3))
inds = results['labels']
for i in range(num_nodes):
plt.annotate(
'{}: {:.4f}'.format(
self.idx_to_cls(inds[i] - 1), nodes[i, inds[i]]),
(boxes[i][6], boxes[i][7]),
color='r' if inds[i] == 1 else 'b')
edges = results['edges']
if 'nodes' not in results:
edges = (edges[:, None] == edges[None]).float()
for j in range(i + 1, num_nodes):
edge_score = max(edges[i][j], edges[j][i])
if edge_score > self.thresholds['edge']:
x1 = sum(boxes[i][:3:2]) // 2
y1 = sum(boxes[i][3:6:2]) // 2
x2 = sum(boxes[j][:3:2]) // 2
y2 = sum(boxes[j][3:6:2]) // 2
plt.plot((x1, x2), (y1, y2), 'r')
plt.annotate(
'{:.4f}'.format(edge_score),
((x1 + x2) // 2, (y1 + y2) // 2),
color='r')
def compute_relation(self, boxes):
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
dxs = (x1s[:, 0][None] - x1s) / self.norm
dys = (y1s[:, 0][None] - y1s) / self.norm
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
whs = ws / hs + np.zeros_like(xhhs)
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
return relations, bboxes
def ann_numpy(self, results):
boxes, texts = results['boxes'], results['texts']
boxes = np.array(boxes, np.int32)
if boxes[0, 1] > boxes[0, -1]:
boxes = boxes[:, [6, 7, 4, 5, 2, 3, 0, 1]]
relations, bboxes = self.compute_relation(boxes)
labels = results.get('labels', None)
if labels is not None:
labels = np.array(labels, np.int32)
edges = results.get('edges', None)
if edges is not None:
labels = labels[:, None]
edges = np.array(edges)
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
if self.directed:
edges = (edges & labels == 1).astype(np.int32)
np.fill_diagonal(edges, -1)
labels = np.concatenate([labels, edges], -1)
return dict(
bboxes=bboxes,
relations=relations,
texts=self.pad_text(texts),
ann_infos = dict(
boxes=boxes,
texts=texts,
text_inds=text_inds,
edges=edges,
labels=labels)
def image_size(self, filename):
img_path = osp.join(self.img_prefix, filename)
img = Image.open(img_path)
return img.size
return self.list_to_numpy(ann_infos)
def load_annotations(self, ann_file):
self.anns, data_infos = [], []
def prepare_train_img(self, index):
"""Get training data and annotations from pipeline.
self.gts = dict()
img_infos = mmcv.list_from_file(ann_file)
for img_info in img_infos:
filename, annname = img_info.split()
results = self.parse_lines(annname)
width, height = self.image_size(filename)
Args:
index (int): Index of data.
data_infos.append(
dict(filename=filename, width=width, height=height))
ann = self.ann_numpy(results)
self.anns.append(ann)
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]
img_info = {
'filename': img_ann_info['file_name'],
'height': img_ann_info['height'],
'width': img_ann_info['width']
}
ann_info = self._parse_anno_info(img_ann_info['annotations'])
results = dict(img_info=img_info, ann_info=ann_info)
return data_infos
self.pre_pipeline(results)
def pad_text(self, texts):
max_len = max([len(text) for text in texts])
padded_texts = -np.ones((len(texts), max_len), np.int32)
for idx, text in enumerate(texts):
padded_texts[idx, :len(text)] = np.array(text)
return padded_texts
def get_ann_info(self, idx):
return self.anns[idx]
def prepare_test_img(self, idx):
return self.prepare_train_img(idx)
return self.pipeline(results)
def evaluate(self,
results,
@@ -278,18 +154,65 @@ class KIEDataset(CustomDataset):
def compute_macro_f1(self, results, ignores=[]):
node_preds = []
for result in results:
node_gts = []
for idx, result in enumerate(results):
node_preds.append(result['nodes'])
node_preds = torch.cat(node_preds)
box_ann_infos = self.data_infos[idx]['annotations']
node_gt = [box_ann_info['label'] for box_ann_info in box_ann_infos]
node_gts.append(torch.Tensor(node_gt))
node_gts = [
torch.from_numpy(ann['labels'][:, 0]).to(node_preds.device)
for ann in self.anns
]
node_gts = torch.cat(node_gts)
node_preds = torch.cat(node_preds)
node_gts = torch.cat(node_gts).int().to(node_preds.device)
node_f1s = compute_f1_score(node_preds, node_gts, ignores)
return {
'macro_f1': node_f1s.mean(),
}
def list_to_numpy(self, ann_infos):
"""Convert list to np.ndarray."""
boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds']
boxes = np.array(boxes, np.int32)
relations, bboxes = self.compute_relation(boxes)
labels = ann_infos.get('labels', None)
if labels is not None:
labels = np.array(labels, np.int32)
edges = ann_infos.get('edges', None)
if edges is not None:
labels = labels[:, None]
edges = np.array(edges)
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
if self.directed:
edges = (edges & labels == 1).astype(np.int32)
np.fill_diagonal(edges, -1)
labels = np.concatenate([labels, edges], -1)
padded_text_inds = self.pad_text_ind(text_inds)
return dict(
bboxes=bboxes,
relations=relations,
texts=padded_text_inds,
labels=labels)
def pad_text_ind(self, text_inds):
"""Pad text index to same length."""
max_len = max([len(text_ind) for text_ind in text_inds])
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
for idx, text_ind in enumerate(text_inds):
padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
return padded_text_inds
def compute_relation(self, boxes):
"""Compute relation between every two boxes."""
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
dxs = (x1s[:, 0][None] - x1s) / self.norm
dys = (y1s[:, 0][None] - y1s) / self.norm
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
whs = ws / hs + np.zeros_like(xhhs)
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
return relations, bboxes
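
As a sanity check on the geometry above, a small worked example with two hypothetical axis-aligned boxes (8-point format, default norm=10.; `dataset` is a built KIEDataset):

import numpy as np
boxes = np.array([[0, 0, 100, 0, 100, 20, 0, 20],
                  [0, 40, 60, 40, 60, 60, 0, 60]], np.int32)
relations, bboxes = dataset.compute_relation(boxes)
# bboxes          -> [[0, 0, 100, 20], [0, 40, 60, 60]]   (x1, y1, x2, y2)
# relations[0, 1] -> [dx, dy, w0/h0, h1/h0, w1/h0]
#                 ~= [0.0, 4.0, 4.81, 1.00, 2.90]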

View File

@@ -1,9 +1,13 @@
import warnings
import mmcv
from torch import nn
from torch.nn import functional as F
from mmdet.core import bbox2roi
from mmdet.models.builder import DETECTORS, build_roi_extractor
from mmdet.models.detectors import SingleStageDetector
from mmocr.core import imshow_edge_node
@DETECTORS.register_module()
@@ -13,6 +17,9 @@ class SDMGR(SingleStageDetector):
Args:
visual_modality (bool): Whether to use the visual modality.
class_list (None | str): Mapping file from class index to
class name. If None, the class index is shown in
`show_result`; otherwise the class name is shown.
"""
def __init__(self,
@@ -26,7 +33,8 @@ class SDMGR(SingleStageDetector):
visual_modality=False,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
class_list=None):
super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
pretrained)
self.visual_modality = visual_modality
@@ -38,6 +46,7 @@ class SDMGR(SingleStageDetector):
self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
else:
self.extractor = None
self.class_list = class_list
def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
gt_labels):
@@ -85,3 +94,61 @@ class SDMGR(SingleStageDetector):
feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
return feats.view(feats.size(0), -1)
return None
def show_result(self,
img,
result,
boxes,
win_name='',
show=False,
wait_time=0,
out_file=None,
**kwargs):
"""Draw `result` on `img`.
Args:
img (str or tensor): The image to be displayed.
result (dict): The results to draw on `img`.
boxes (list): Bounding boxes of the text regions in `img`.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The output filename.
Default: None.
Returns:
img (tensor): The image with the predicted results drawn on it.
"""
img = mmcv.imread(img)
img = img.copy()
idx_to_cls = {}
if self.class_list is not None:
with open(self.class_list, 'r') as fr:
for line in fr:
line = line.strip().split()
class_idx, class_label = line
idx_to_cls[class_idx] = class_label
# if out_file specified, do not show image in window
if out_file is not None:
show = False
img = imshow_edge_node(
img,
result,
boxes,
idx_to_cls=idx_to_cls,
show=show,
win_name=win_name,
wait_time=wait_time,
out_file=out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
return img
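
For reference, the parser above expects each line of the class_list file to be '<index> <class name>' separated by whitespace, e.g. (illustrative entries in the wildreceipt style):

0 Ignore
1 Store_name_value
2 Store_name_key
...
25 Others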

View File

@@ -0,0 +1,107 @@
import argparse
import os
import os.path as osp
import mmcv
import torch
from mmcv import Config
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.models import build_detector
def test(model, data_loader, show=False, out_dir=None):
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
batch_size = len(result)
if show or out_dir:
img_tensor = data['img'].data[0]
img_metas = data['img_metas'].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
assert len(imgs) == len(img_metas)
gt_bboxes = [data['gt_bboxes'].data[0][0].numpy().tolist()]
for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
if out_dir:
out_file = osp.join(out_dir, img_meta['ori_filename'])
else:
out_file = None
model.module.show_result(
img_show,
result[i],
gt_bboxes[i],
show=show,
out_file=out_file)
for _ in range(batch_size):
prog_bar.update()
return results
def parse_args():
parser = argparse.ArgumentParser(
description='Visualize the results of a KIE model with MMOCR.')
parser.add_argument('config', help='Test config file path.')
parser.add_argument('checkpoint', help='Checkpoint file.')
parser.add_argument('--show', action='store_true', help='Show results.')
parser.add_argument(
'--show-dir', help='Directory where the output images will be saved.')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
assert args.show or args.show_dir, \
('Please specify at least one operation (show the results'
' / save the results) with the argument "--show" or "--show-dir".')
cfg = Config.fromfile(args.config)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.model.pretrained = None
distributed = False
# build the dataloader
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
load_checkpoint(model, args.checkpoint, map_location='cpu')
model = MMDataParallel(model, device_ids=[0])
test(model, data_loader, args.show, args.show_dir)
if __name__ == '__main__':
main()
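
A typical invocation of the new tool (config and checkpoint paths are hypothetical):

python tools/kie_test_imgs.py \
    configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py \
    work_dirs/sdmgr/latest.pth \
    --show-dir vis_kie/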

View File

@@ -0,0 +1,21 @@
#!/bin/bash
DATE=`date +%Y-%m-%d`
TIME=`date +"%H-%M-%S"`
if [ $# -lt 3 ]
then
echo "Usage: bash $0 CONFIG CHECKPOINT SHOW_DIR"
exit
fi
CONFIG_FILE=$1
CHECKPOINT=$2
SHOW_DIR=$3_${DATE}_${TIME}
mkdir -p ${SHOW_DIR} &&
python tools/kie_test_imgs.py \
${CONFIG_FILE} \
${CHECKPOINT} \
--show-dir ${SHOW_DIR}
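
The wrapper appends a date/time stamp to the output directory, so an invocation such as (script name and paths hypothetical):

bash tools/kie_test_imgs.sh CONFIG CHECKPOINT vis_kie

writes the visualizations into a directory like vis_kie_2021-04-05_10-34-37/.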