mirror of https://github.com/open-mmlab/mmocr.git
126 lines
4.5 KiB
Python
126 lines
4.5 KiB
Python
#!/usr/bin/env python
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import shutil
|
|
import time
|
|
from argparse import ArgumentParser
|
|
from itertools import compress
|
|
|
|
import mmcv
|
|
from mmcv.utils import ProgressBar
|
|
|
|
from mmocr.apis import init_detector, model_inference
|
|
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
|
|
from mmocr.datasets import build_dataset # noqa: F401
|
|
from mmocr.models import build_detector # noqa: F401
|
|
from mmocr.utils import get_root_logger, list_from_file, list_to_file
|
|
|
|
|
|
def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Write per-image recognition results to text files under ``res_dir``.

    Three files are produced, each holding one ``<img> <pred> <gt>`` line
    per sample: ``results.txt`` (every sample), ``correct.txt`` (prediction
    equals ground truth) and ``wrong.txt`` (all remaining samples).

    Args:
        img_paths (list[str]): Image path of each sample.
        pred_labels (list[str]): Predicted text of each sample.
        gt_labels (list[str]): Ground-truth text of each sample.
        res_dir (str): Directory the three files are written into.
    """
    num_samples = len(img_paths)
    assert num_samples == len(pred_labels) and num_samples == len(gt_labels)

    # Build all three line lists in a single pass over the samples.
    all_rows, hit_rows, miss_rows = [], [], []
    for path, predicted, expected in zip(img_paths, pred_labels, gt_labels):
        row = f'{path} {predicted} {expected}'
        all_rows.append(row)
        if predicted == expected:
            hit_rows.append(row)
        else:
            miss_rows.append(row)

    list_to_file(osp.join(res_dir, 'results.txt'), all_rows)
    list_to_file(osp.join(res_dir, 'correct.txt'), hit_rows)
    list_to_file(osp.join(res_dir, 'wrong.txt'), miss_rows)
|
|
|
|
|
|
def main():
    """Run text recognition on a list of images and save/evaluate results.

    Command-line arguments:
        img_root_path: Root directory the image paths are relative to.
        img_list: Text file with one ``<image path> [<gt label>]`` entry per
            line. Blank lines are ignored.
        config: Model config file.
        checkpoint: Model checkpoint file.
        --out_dir: Directory receiving the log, the result text files and
            the visualizations (default ``./results``).
        --show: Show each visualization instead of saving it.
        --device: Device used for inference (default ``cuda:0``).

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist
            under ``img_root_path``.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file lives inside ``out_dir``, so the directory must exist
    # before the logger is created; otherwise a fresh run fails right here.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Unwrap a wrapper object so ``show_result`` is called on the model.
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        if not item_list:
            # Blank line in the list file: there is no image to process.
            continue
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten the relative path so every visualization lands directly
        # in the output directory.
        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With ``--show`` no visualization is written to ``out_file`` (an
        # empty path is passed above), so there is nothing to copy into the
        # correct/wrong folders.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Metrics are only computed when every processed image has a gt label.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
|
|
|
|
|
|
# Entry point: run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
|