mmocr/tools/ocr_test_imgs.py

133 lines
5.0 KiB
Python

import os.path as osp
import shutil
import time
from argparse import ArgumentParser
import mmcv
import torch
from mmcv.utils import ProgressBar
from mmdet.apis import init_detector
from mmdet.utils import get_root_logger
from mmocr.apis import model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
def save_results(img_paths, pred_labels, gt_labels, res_dir):
"""Save predicted results to txt file.
Args:
img_paths (list[str])
pred_labels (list[str])
gt_labels (list[str])
res_dir (str)
"""
assert len(img_paths) == len(pred_labels) == len(gt_labels)
res_file = osp.join(res_dir, 'results.txt')
correct_file = osp.join(res_dir, 'correct.txt')
wrong_file = osp.join(res_dir, 'wrong.txt')
with open(res_file, 'w') as fw, \
open(correct_file, 'w') as fw_correct, \
open(wrong_file, 'w') as fw_wrong:
for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
gt_labels):
fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
if pred_label == gt_label:
fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label +
'\n')
else:
fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
'\n')
def main():
parser = ArgumentParser()
parser.add_argument('--img_root_path', type=str, help='Image root path')
parser.add_argument('--img_list', type=str, help='Image path list file')
parser.add_argument('--config', type=str, help='Config file')
parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
parser.add_argument(
'--out_dir', type=str, default='./results', help='Dir to save results')
parser.add_argument(
'--show', action='store_true', help='show image or save')
args = parser.parse_args()
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(args.out_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level='INFO')
# build the model from a config file and a checkpoint file
device = 'cuda:' + str(torch.cuda.current_device())
model = init_detector(args.config, args.checkpoint, device=device)
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
# Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
mmcv.mkdir_or_exist(out_vis_dir)
correct_vis_dir = osp.join(args.out_dir, 'correct')
mmcv.mkdir_or_exist(correct_vis_dir)
wrong_vis_dir = osp.join(args.out_dir, 'wrong')
mmcv.mkdir_or_exist(wrong_vis_dir)
img_paths, pred_labels, gt_labels = [], [], []
total_img_num = sum([1 for _ in open(args.img_list)])
progressbar = ProgressBar(task_num=total_img_num)
num_gt_label = 0
with open(args.img_list, 'r') as fr:
for line in fr:
progressbar.update()
item_list = line.strip().split()
img_file = item_list[0]
gt_label = ''
if len(item_list) >= 2:
gt_label = item_list[1]
num_gt_label += 1
img_path = osp.join(args.img_root_path, img_file)
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = model_inference(model, img_path)
pred_label = result['text']
out_img_name = '_'.join(img_file.split('/'))
out_file = osp.join(out_vis_dir, out_img_name)
kwargs_dict = {
'gt_label': gt_label,
'show': args.show,
'out_file': '' if args.show else out_file
}
model.show_result(img_path, result, **kwargs_dict)
if gt_label != '':
if gt_label == pred_label:
dst_file = osp.join(correct_vis_dir, out_img_name)
else:
dst_file = osp.join(wrong_vis_dir, out_img_name)
shutil.copy(out_file, dst_file)
img_paths.append(img_path)
gt_labels.append(gt_label)
pred_labels.append(pred_label)
# Save results
save_results(img_paths, pred_labels, gt_labels, args.out_dir)
if num_gt_label == len(pred_labels):
# eval
eval_results = eval_ocr_metric(pred_labels, gt_labels)
logger.info('\n' + '-' * 100)
info = 'eval on testset with img_root_path ' + \
f'{args.img_root_path} and img_list {args.img_list}\n'
logger.info(info)
logger.info(eval_results)
print(f'\nInference done, and results saved in {args.out_dir}\n')
if __name__ == '__main__':
main()