mmocr/tools/test_imgs.py
lizz 44ca9c2a61
Remove usage of \ (#49)

* Remove usage of \
* rebase
* typos
* Remove test dependency on tools/
* typo
* KIE in keywords
* some renames
* kill isort skip
* aggregation discrimination
* tiny
* fix bug: model infer on cpu

Signed-off-by: lizz <lizz@sensetime.com>
Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
2021-04-06 12:16:46 +00:00

import codecs
import os.path as osp
from argparse import ArgumentParser

import mmcv
import numpy as np
import torch
from mmcv.utils import ProgressBar
from mmdet.apis import inference_detector, init_detector

from mmocr.core.evaluation.utils import filter_result
from mmocr.models import build_detector  # noqa: F401


def gen_target_path(target_root_path, src_name, suffix):
    """Gen target file path.

    Args:
        target_root_path (str): The target root path.
        src_name (str): The source file name.
        suffix (str): The suffix of target file.
    """
    assert isinstance(target_root_path, str)
    assert isinstance(src_name, str)
    assert isinstance(suffix, str)

    dir_name, file_name = osp.split(src_name)
    name, file_suffix = osp.splitext(file_name)

    return target_root_path + '/' + name + suffix
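
# Illustrative usage (not part of the original script; the paths are made up):
# gen_target_path('./out_txt_dir', 'imgs/img_1.jpg', '.txt') returns
# './out_txt_dir/img_1.txt'.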


def save_2darray(mat, file_name):
    """Save 2d array to txt file.

    Args:
        mat (ndarray): 2d-array of shape (n, m).
        file_name (str): The output file name.
    """
    with codecs.open(file_name, 'w', 'utf-8') as fw:
        for row in mat:
            row_str = ','.join([str(x) for x in row])
            fw.write(row_str + '\n')
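
# Illustrative usage (not part of the original script; 'out.txt' is a
# placeholder): save_2darray(np.array([[1, 2], [3, 4]]), 'out.txt') writes the
# two lines "1,2" and "3,4".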


def save_bboxes_quadrangles(bboxes_with_scores,
                            quadrangles_with_scores,
                            img_name,
                            out_bbox_txt_dir,
                            out_quadrangle_txt_dir,
                            score_thr=0.3,
                            save_score=True):
    """Save results of detected bounding boxes and quadrangles to txt files.

    Args:
        bboxes_with_scores (ndarray): Detected bboxes of shape (n,5).
        quadrangles_with_scores (ndarray): Detected quadrangles of shape (n,9).
        img_name (str): Image file name.
        out_bbox_txt_dir (str): Dir of txt files to save detected bbox
            results.
        out_quadrangle_txt_dir (str): Dir of txt files to save quadrangle
            results.
        score_thr (float, optional): Score threshold for bboxes.
        save_score (bool, optional): Whether to append the score to each line
            so that the best threshold can be searched when evaluating.
    """
    assert bboxes_with_scores.ndim == 2
    assert bboxes_with_scores.shape[1] == 5 or bboxes_with_scores.shape[1] == 9
    assert quadrangles_with_scores.ndim == 2
    assert quadrangles_with_scores.shape[1] == 9
    assert bboxes_with_scores.shape[0] >= quadrangles_with_scores.shape[0]
    assert isinstance(img_name, str)
    assert isinstance(out_bbox_txt_dir, str)
    assert isinstance(out_quadrangle_txt_dir, str)
    assert isinstance(score_thr, float)
    assert 0 <= score_thr < 1

    # filter out results below the score threshold
    initial_valid_bboxes, valid_bbox_scores = filter_result(
        bboxes_with_scores[:, :-1], bboxes_with_scores[:, -1], score_thr)
    if initial_valid_bboxes.shape[1] == 4:
        # expand axis-aligned boxes (x1, y1, x2, y2) into 4 corner points
        # ordered clockwise: (x1, y1), (x2, y1), (x2, y2), (x1, y2)
        valid_bboxes = np.ndarray(
            (initial_valid_bboxes.shape[0], 8)).astype(int)
        idx_list = [0, 1, 2, 1, 2, 3, 0, 3]
        for i in range(8):
            valid_bboxes[:, i] = initial_valid_bboxes[:, idx_list[i]]
    elif initial_valid_bboxes.shape[1] == 8:
        valid_bboxes = initial_valid_bboxes

    valid_quadrangles, valid_quadrangle_scores = filter_result(
        quadrangles_with_scores[:, :-1], quadrangles_with_scores[:, -1],
        score_thr)

    # gen target file path
    bbox_txt_file = gen_target_path(out_bbox_txt_dir, img_name, '.txt')
    quadrangle_txt_file = gen_target_path(out_quadrangle_txt_dir, img_name,
                                          '.txt')

    # save txt, optionally appending the detection score to each row
    if save_score:
        valid_bboxes = np.concatenate(
            (valid_bboxes, valid_bbox_scores.reshape(-1, 1)), axis=1)
        valid_quadrangles = np.concatenate(
            (valid_quadrangles, valid_quadrangle_scores.reshape(-1, 1)),
            axis=1)

    save_2darray(valid_bboxes, bbox_txt_file)
    save_2darray(valid_quadrangles, quadrangle_txt_file)
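
# Illustrative output format (derived from the code above, not printed by the
# script): each line of the saved txt files holds one detection as
# comma-separated values,
#   x1,y1,x2,y2,x3,y3,x4,y4[,score]
# with the trailing score present only when save_score=True.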


def main():
    parser = ArgumentParser()
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save visualized images and bboxes')
    args = parser.parse_args()

    assert 0 < args.score_thr < 1

    # build the model from a config file and a checkpoint file
    device = 'cuda:' + str(torch.cuda.current_device())
    model = init_detector(args.config, args.checkpoint, device=device)
    if hasattr(model, 'module'):
        model = model.module
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)

    total_img_num = sum([1 for _ in open(args.img_list)])
    progressbar = ProgressBar(task_num=total_img_num)
    with codecs.open(args.img_list, 'r', 'utf-8') as fr:
        for line in fr:
            progressbar.update()
            img_path = args.img_root + '/' + line.strip()
            if not osp.exists(img_path):
                raise FileNotFoundError(img_path)
            # Test a single image
            result = inference_detector(model, img_path)
            img_name = osp.basename(img_path)
            out_file = osp.join(out_vis_dir, img_name)
            kwargs_dict = {
                'score_thr': args.score_thr,
                'show': False,
                'out_file': out_file
            }
            model.show_result(img_path, result, **kwargs_dict)

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()
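
# Example invocation (a sketch; the config, checkpoint and data paths are
# placeholders, not files that ship with this script):
#
#   python tools/test_imgs.py CONFIG.py CHECKPOINT.pth IMG_ROOT IMG_LIST.txt \
#       --score-thr 0.5 --out-dir ./results
#
# Visualized detections are written to ./results/out_vis_dir/.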