import codecs
import os.path as osp
from argparse import ArgumentParser

import mmcv
import numpy as np
import torch
from mmcv.utils import ProgressBar

from mmdet.apis import inference_detector, init_detector
from mmocr.core.evaluation.utils import filter_result
from mmocr.models import build_detector  # noqa: F401
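
# Usage sketch (the exact script name below is an assumption; the positional
# and optional arguments match the parser defined in main()):
#   python test_imgs.py ${CONFIG} ${CHECKPOINT} ${IMG_ROOT} ${IMG_LIST} \
#       --score-thr 0.5 --out-dir ./results
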
def gen_target_path(target_root_path, src_name, suffix):
    """Generate the target file path.

    Args:
        target_root_path (str): The target root path.
        src_name (str): The source file name.
        suffix (str): The suffix of the target file.
    """
    assert isinstance(target_root_path, str)
    assert isinstance(src_name, str)
    assert isinstance(suffix, str)

    dir_name, file_name = osp.split(src_name)
    name, file_suffix = osp.splitext(file_name)
    return target_root_path + '/' + name + suffix
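
# For example (hypothetical paths), gen_target_path('./results/out_txt_dir',
# 'imgs/img_1.jpg', '.txt') above returns './results/out_txt_dir/img_1.txt'.
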

def save_2darray(mat, file_name):
    """Save a 2d array to a txt file.

    Args:
        mat (ndarray): 2d array of shape (n, m).
        file_name (str): The output file name.
    """
    with codecs.open(file_name, 'w', 'utf-8') as fw:
        for row in mat:
            row_str = ','.join([str(x) for x in row])
            fw.write(row_str + '\n')
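
# save_2darray writes one comma-separated line per row, e.g. a row with the
# values 10, 20, 50, 80 (hypothetical) is stored as the line '10,20,50,80'.
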

def save_bboxes_quadrangles(bboxes_with_scores,
                            quadrangles_with_scores,
                            img_name,
                            out_bbox_txt_dir,
                            out_quadrangle_txt_dir,
                            score_thr=0.3,
                            save_score=True):
    """Save detected bounding boxes and quadrangles to txt files.

    Args:
        bboxes_with_scores (ndarray): Detected bboxes of shape (n, 5).
        quadrangles_with_scores (ndarray): Detected quadrangles of shape
            (m, 9).
        img_name (str): Image file name.
        out_bbox_txt_dir (str): Dir of txt files to save detected bbox
            results.
        out_quadrangle_txt_dir (str): Dir of txt files to save quadrangle
            results.
        score_thr (float, optional): Score threshold for bboxes.
        save_score (bool, optional): Whether to append the score to each
            line, which is useful for searching the best threshold during
            evaluation.
    """
    assert bboxes_with_scores.ndim == 2
    assert bboxes_with_scores.shape[1] in (5, 9)
    assert quadrangles_with_scores.ndim == 2
    assert quadrangles_with_scores.shape[1] == 9
    assert bboxes_with_scores.shape[0] >= quadrangles_with_scores.shape[0]
    assert isinstance(img_name, str)
    assert isinstance(out_bbox_txt_dir, str)
    assert isinstance(out_quadrangle_txt_dir, str)
    assert isinstance(score_thr, float)
    assert 0 <= score_thr < 1

    # filter out results below the score threshold
    initial_valid_bboxes, valid_bbox_scores = filter_result(
        bboxes_with_scores[:, :-1], bboxes_with_scores[:, -1], score_thr)
    if initial_valid_bboxes.shape[1] == 4:
        # expand each axis-aligned box (x1, y1, x2, y2) into its four corner
        # points (x1, y1), (x2, y1), (x2, y2), (x1, y2), clockwise from the
        # top-left corner
        valid_bboxes = np.zeros((initial_valid_bboxes.shape[0], 8), dtype=int)
        idx_list = [0, 1, 2, 1, 2, 3, 0, 3]
        for i in range(8):
            valid_bboxes[:, i] = initial_valid_bboxes[:, idx_list[i]]
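        # e.g. a kept box [10., 20., 50., 80.] becomes the 8-value row
        # [10, 20, 50, 20, 50, 80, 10, 80] (hypothetical coordinates)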
    elif initial_valid_bboxes.shape[1] == 8:
        valid_bboxes = initial_valid_bboxes

    valid_quadrangles, valid_quadrangle_scores = filter_result(
        quadrangles_with_scores[:, :-1], quadrangles_with_scores[:, -1],
        score_thr)

    # generate target file paths
    bbox_txt_file = gen_target_path(out_bbox_txt_dir, img_name, '.txt')
    quadrangle_txt_file = gen_target_path(out_quadrangle_txt_dir, img_name,
                                          '.txt')

    # optionally append the detection score as the last value of each row
    if save_score:
        valid_bboxes = np.concatenate(
            (valid_bboxes, valid_bbox_scores.reshape(-1, 1)), axis=1)
        valid_quadrangles = np.concatenate(
            (valid_quadrangles, valid_quadrangle_scores.reshape(-1, 1)),
            axis=1)

    # save txt files
    save_2darray(valid_bboxes, bbox_txt_file)
    save_2darray(valid_quadrangles, quadrangle_txt_file)
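
# With save_score=True the saved rows become float after the concatenation
# above, so each line has the form 'x1,y1,x2,y2,x3,y3,x4,y4,score', e.g.
# '10.0,20.0,50.0,20.0,50.0,80.0,10.0,80.0,0.95' (hypothetical values).
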

def main():
    parser = ArgumentParser()
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save visualized images and bbox results')
    args = parser.parse_args()

    assert 0 < args.score_thr < 1

    # build the model from a config file and a checkpoint file
    device = 'cuda:' + str(torch.cuda.current_device())
    model = init_detector(args.config, args.checkpoint, device=device)
    if hasattr(model, 'module'):
        model = model.module
    if model.cfg.data.test['type'] == 'ConcatDataset':
        # a ConcatDataset has no pipeline of its own, so reuse the pipeline
        # of its first sub-dataset for inference
        model.cfg.data.test.pipeline = \
            model.cfg.data.test['datasets'][0].pipeline

    # start inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
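    # The image list file is read line by line below; each line is expected
    # to be an image path relative to img_root, e.g. 'img_1.jpg'
    # (hypothetical name).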
    with codecs.open(args.img_list, 'r', 'utf-8') as fr:
        total_img_num = sum(1 for _ in fr)
    progressbar = ProgressBar(task_num=total_img_num)
    with codecs.open(args.img_list, 'r', 'utf-8') as fr:
        for line in fr:
            progressbar.update()
            img_path = args.img_root + '/' + line.strip()
            if not osp.exists(img_path):
                raise FileNotFoundError(img_path)
            # test a single image and save the visualized result
            result = inference_detector(model, img_path)
            img_name = osp.basename(img_path)
            out_file = osp.join(out_vis_dir, img_name)
            kwargs_dict = {
                'score_thr': args.score_thr,
                'show': False,
                'out_file': out_file
            }
            model.show_result(img_path, result, **kwargs_dict)

    print(f'\nInference done, and results saved in {args.out_dir}\n')

if __name__ == '__main__':
    main()