mirror of https://github.com/open-mmlab/mmocr.git
* fix #279: save detect results * rename * set device as arg * rm bash filepull/287/head
parent
3bfbb2b619
commit
87a7dcee0a
|
@ -36,10 +36,10 @@ The predicted result will be saved as `demo/output.jpg`.
|
|||
|
||||
```shell
|
||||
# for text detection
|
||||
sh tools/test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
|
||||
./tools/det_test_imgs.py ${IMG_ROOT_PATH} ${IMG_LIST} ${CONFIG_FILE} ${CHECKPOINT_FILE} --out-dir ${RESULTS_DIR}
|
||||
|
||||
# for text recognition
|
||||
sh tools/ocr_test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
|
||||
./tools/recog_test_imgs.py ${IMG_ROOT_PATH} ${IMG_LIST} ${CONFIG_FILE} ${CHECKPOINT_FILE} --out-dir ${RESULTS_DIR}
|
||||
```
|
||||
It will save both the prediction results and visualized images to `${RESULTS_DIR}`.
|
||||
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
#!/usr/bin/env python
|
||||
import os.path as osp
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import mmcv
|
||||
from mmcv.utils import ProgressBar
|
||||
|
||||
from mmdet.apis import inference_detector, init_detector
|
||||
from mmocr.models import build_detector # noqa: F401
|
||||
from mmocr.utils import list_from_file, list_to_file
|
||||
|
||||
|
||||
def gen_target_path(target_root_path, src_name, suffix):
    """Gen target file path.

    The directory part and the last extension of ``src_name`` are dropped;
    the remaining stem is concatenated with ``suffix`` and placed under
    ``target_root_path``.

    Args:
        target_root_path (str): The target root path.
        src_name (str): The source file name.
        suffix (str): The suffix of target file.

    Returns:
        str: ``target_root_path`` joined with ``<stem of src_name><suffix>``.

    Raises:
        AssertionError: If any argument is not a ``str``.
    """
    assert isinstance(target_root_path, str)
    assert isinstance(src_name, str)
    assert isinstance(suffix, str)

    # Only the last extension is stripped, e.g. 'a.b.jpg' -> 'a.b'.
    stem = osp.splitext(osp.basename(src_name))[0]
    return osp.join(target_root_path, stem + suffix)
|
||||
|
||||
|
||||
def save_results(result, out_dir, img_name, score_thr=0.3):
    """Save result of detected bounding boxes (quadrangle or polygon) to txt
    file.

    Every kept boundary is written as one comma-separated line: the
    coordinates rounded to integers, followed by the score. The score is
    written unrounded; rounding it together with the coordinates would
    collapse it to 0 or 1 and discard the confidence value.

    Args:
        result (dict): Text Detection result for one image. Must contain the
            key ``'boundary_result'``; the last element of each entry is
            compared against ``score_thr``.
        out_dir (str): Dir of txt files to save detected results.
        img_name (str): Image file name.
        score_thr (float, optional): Score threshold to filter bboxes. Only
            entries whose score is strictly greater are kept. Must lie in
            the open interval (0, 1).

    Raises:
        AssertionError: If ``'boundary_result'`` is missing from ``result``
            or ``score_thr`` is outside (0, 1).
    """
    assert 'boundary_result' in result
    assert 0 < score_thr < 1

    txt_file = gen_target_path(out_dir, img_name, '.txt')

    lines = []
    for res in result['boundary_result']:
        score = res[-1]
        if score > score_thr:
            fields = [str(round(x)) for x in res[:-1]]
            fields.append(str(score))
            lines.append(','.join(fields))
    list_to_file(txt_file, lines)
|
||||
|
||||
|
||||
def main():
    """Run text detection on every image named in a list file.

    Command line: ``img_root img_list config checkpoint [--score-thr S]
    [--out-dir DIR] [--device DEV]``.

    For each line of ``img_list`` the image ``img_root/<line>`` is passed to
    ``inference_detector``. The boundaries are saved by ``save_results``
    under ``<out-dir>/out_txt_dir`` and the visualization produced by
    ``model.show_result`` is written to ``<out-dir>/out_vis_dir``.

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist.
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--score-thr`` outside the open interval (0, 1).
    """
    parser = ArgumentParser()
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save '
        'visualize images '
        'and bbox')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # Report bad input through argparse instead of ``assert`` so that the
    # check still runs under ``python -O`` and the user sees the usage text.
    if not 0 < args.score_thr < 1:
        parser.error('--score-thr must be in the open interval (0, 1)')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Presumably a parallel wrapper; use the wrapped model directly.
        model = model.module
    if model.cfg.data.test['type'] == 'ConcatDataset':
        # A ConcatDataset config keeps the pipeline on its sub-datasets;
        # expose the first one at the top level of the test config.
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
    mmcv.mkdir_or_exist(out_txt_dir)

    # Count lines for the progress bar; ``with`` closes the handle promptly.
    with open(args.img_list) as list_file:
        total_img_num = sum(1 for _ in list_file)
    progressbar = ProgressBar(task_num=total_img_num)
    for line in list_from_file(args.img_list):
        progressbar.update()
        img_path = osp.join(args.img_root, line.strip())
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = inference_detector(model, img_path)
        img_name = osp.basename(img_path)
        # save result
        save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)
        # show result
        out_file = osp.join(out_vis_dir, img_name)
        kwargs_dict = {
            'score_thr': args.score_thr,
            'show': False,
            'out_file': out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
|
||||
|
||||
|
||||
# Run the command line entry point only when executed as a script.
if __name__ == '__main__':
    main()
|
|
@ -1,25 +0,0 @@
|
|||
#!/bin/bash
# Run tools/ocr_test_imgs.py on a list of images.
# Results are written to "<RESULTS_DIR>_<date>_<time>", so the output
# directory name carries the start time of the run.

DATE=$(date +%Y-%m-%d)
TIME=$(date +"%H-%M-%S")

if [ "$#" -lt 5 ]
then
    echo "Usage: bash $0 CONFIG CHECKPOINT IMG_PREFIX IMG_LIST RESULTS_DIR"
    # A usage error must not be reported as success to the caller.
    exit 1
fi

CONFIG_FILE=$1
CHECKPOINT=$2
IMG_ROOT_PATH=$3
IMG_LIST=$4
OUT_DIR="${5}_${DATE}_${TIME}"

# Quote every expansion so paths containing spaces stay single arguments.
mkdir -p "${OUT_DIR}" &&

python tools/ocr_test_imgs.py \
    --img_root_path "${IMG_ROOT_PATH}" \
    --img_list "${IMG_LIST}" \
    --config "${CONFIG_FILE}" \
    --checkpoint "${CHECKPOINT}" \
    --out_dir "${OUT_DIR}"
|
|
@ -6,7 +6,6 @@ from argparse import ArgumentParser
|
|||
from itertools import compress
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.utils import ProgressBar
|
||||
|
||||
from mmdet.apis import init_detector
|
||||
|
@ -40,14 +39,16 @@ def save_results(img_paths, pred_labels, gt_labels, res_dir):
|
|||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--img_root_path', type=str, help='Image root path')
|
||||
parser.add_argument('--img_list', type=str, help='Image path list file')
|
||||
parser.add_argument('--config', type=str, help='Config file')
|
||||
parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
|
||||
parser.add_argument('img_root_path', type=str, help='Image root path')
|
||||
parser.add_argument('img_list', type=str, help='Image path list file')
|
||||
parser.add_argument('config', type=str, help='Config file')
|
||||
parser.add_argument('checkpoint', type=str, help='Checkpoint file')
|
||||
parser.add_argument(
|
||||
'--out_dir', type=str, default='./results', help='Dir to save results')
|
||||
parser.add_argument(
|
||||
'--show', action='store_true', help='show image or save')
|
||||
parser.add_argument(
|
||||
'--device', default='cuda:0', help='Device used for inference.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# init the logger before other steps
|
||||
|
@ -56,8 +57,7 @@ def main():
|
|||
logger = get_root_logger(log_file=log_file, log_level='INFO')
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
device = 'cuda:' + str(torch.cuda.current_device())
|
||||
model = init_detector(args.config, args.checkpoint, device=device)
|
||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
@ -1,165 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
import os.path as osp
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.utils import ProgressBar
|
||||
|
||||
from mmdet.apis import inference_detector, init_detector
|
||||
from mmocr.core.evaluation.utils import filter_result
|
||||
from mmocr.models import build_detector # noqa: F401
|
||||
from mmocr.utils import list_from_file, list_to_file
|
||||
|
||||
|
||||
def gen_target_path(target_root_path, src_name, suffix):
    """Gen target file path.

    The directory part and the last extension of ``src_name`` are dropped;
    the remaining stem is concatenated with ``suffix`` and placed under
    ``target_root_path``.

    Args:
        target_root_path (str): The target root path.
        src_name (str): The source file name.
        suffix (str): The suffix of target file.

    Returns:
        str: ``target_root_path`` joined with ``<stem of src_name><suffix>``.

    Raises:
        AssertionError: If any argument is not a ``str``.
    """
    assert isinstance(target_root_path, str)
    assert isinstance(src_name, str)
    assert isinstance(suffix, str)

    # Only the last extension is stripped, e.g. 'a.b.jpg' -> 'a.b'.
    stem = osp.splitext(osp.basename(src_name))[0]
    return osp.join(target_root_path, stem + suffix)
|
||||
|
||||
|
||||
def save_2darray(mat, file_name):
    """Save 2d array to txt file.

    Each row of ``mat`` becomes one line holding the ``str()`` of its
    elements joined by commas; the lines are handed to ``list_to_file``.

    Args:
        mat (ndarray): 2d-array of shape (n, m).
        file_name (str): The output file name.
    """
    lines = [','.join(map(str, row)) for row in mat]
    list_to_file(file_name, lines)
|
||||
|
||||
|
||||
def save_bboxes_quadrangles(bboxes_with_scores,
                            quadrangles_with_scores,
                            img_name,
                            out_bbox_txt_dir,
                            out_quadrangle_txt_dir,
                            score_thr=0.3,
                            save_score=True):
    """Save results of detected bounding boxes and quadrangles to txt file.

    Both inputs are filtered with ``filter_result`` using ``score_thr`` and
    written by ``save_2darray`` to ``<stem of img_name>.txt`` in their
    respective output directories.

    Args:
        bboxes_with_scores (ndarray): Detected bboxes of shape (n,5) or
            (n,9); the last column is the score.
        quadrangles_with_scores (ndarray): Detected quadrangles of shape (n,9).
        img_name (str): Image file name.
        out_bbox_txt_dir (str): Dir of txt files to save detected bboxes
            results.
        out_quadrangle_txt_dir (str): Dir of txt files to save
            quadrangle results.
        score_thr (float, optional): Score threshold for bboxes. Must be a
            ``float`` in [0, 1).
        save_score (bool, optional): Whether to save score at each line end
            to search best threshold when evaluating.

    Raises:
        AssertionError: If an argument fails the shape/type checks below.
        ValueError: If the filtered bboxes have neither 4 nor 8 columns.
    """
    assert bboxes_with_scores.ndim == 2
    assert bboxes_with_scores.shape[1] == 5 or bboxes_with_scores.shape[1] == 9
    assert quadrangles_with_scores.ndim == 2
    assert quadrangles_with_scores.shape[1] == 9
    assert bboxes_with_scores.shape[0] >= quadrangles_with_scores.shape[0]
    assert isinstance(img_name, str)
    assert isinstance(out_bbox_txt_dir, str)
    assert isinstance(out_quadrangle_txt_dir, str)
    assert isinstance(score_thr, float)
    assert score_thr >= 0 and score_thr < 1

    # filter out invalid results
    initial_valid_bboxes, valid_bbox_scores = filter_result(
        bboxes_with_scores[:, :-1], bboxes_with_scores[:, -1], score_thr)
    num_bbox_cols = initial_valid_bboxes.shape[1]
    if num_bbox_cols == 4:
        # Expand 4 values per box into 8 by selecting columns
        # (0, 1), (2, 1), (2, 3), (0, 3) and truncating to int.
        # NOTE(review): presumably (x1, y1, x2, y2) -> four corner points;
        # confirm the column order against the detector output.
        idx_list = [0, 1, 2, 1, 2, 3, 0, 3]
        valid_bboxes = initial_valid_bboxes[:, idx_list].astype(int)
    elif num_bbox_cols == 8:
        valid_bboxes = initial_valid_bboxes
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            'filtered bboxes must have 4 or 8 columns, '
            f'but got {num_bbox_cols}')

    valid_quadrangles, valid_quadrangle_scores = filter_result(
        quadrangles_with_scores[:, :-1], quadrangles_with_scores[:, -1],
        score_thr)

    # gen target file path
    bbox_txt_file = gen_target_path(out_bbox_txt_dir, img_name, '.txt')
    quadrangle_txt_file = gen_target_path(out_quadrangle_txt_dir, img_name,
                                          '.txt')

    # save txt
    if save_score:
        valid_bboxes = np.concatenate(
            (valid_bboxes, valid_bbox_scores.reshape(-1, 1)), axis=1)
        valid_quadrangles = np.concatenate(
            (valid_quadrangles, valid_quadrangle_scores.reshape(-1, 1)),
            axis=1)

    save_2darray(valid_bboxes, bbox_txt_file)
    save_2darray(valid_quadrangles, quadrangle_txt_file)
|
||||
|
||||
|
||||
def main():
    """Run text detection on every image named in a list file.

    Command line: ``config checkpoint img_root img_list [--score-thr S]
    [--out-dir DIR]``.

    For each line of ``img_list`` the image ``img_root/<line>`` is passed to
    ``inference_detector`` and the visualization produced by
    ``model.show_result`` is written to ``<out-dir>/out_vis_dir``. The model
    is placed on the current CUDA device.

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist.
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--score-thr`` outside the open interval (0, 1).
    """
    parser = ArgumentParser()
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')

    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save '
        'visualize images '
        'and bbox')
    args = parser.parse_args()

    # Report bad input through argparse instead of ``assert`` so that the
    # check still runs under ``python -O`` and the user sees the usage text.
    if not 0 < args.score_thr < 1:
        parser.error('--score-thr must be in the open interval (0, 1)')

    # build the model from a config file and a checkpoint file
    device = 'cuda:' + str(torch.cuda.current_device())
    model = init_detector(args.config, args.checkpoint, device=device)
    if hasattr(model, 'module'):
        # Presumably a parallel wrapper; use the wrapped model directly.
        model = model.module
    if model.cfg.data.test['type'] == 'ConcatDataset':
        # A ConcatDataset config keeps the pipeline on its sub-datasets;
        # expose the first one at the top level of the test config.
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)

    # Count lines for the progress bar; ``with`` closes the handle promptly.
    with open(args.img_list) as list_file:
        total_img_num = sum(1 for _ in list_file)
    progressbar = ProgressBar(task_num=total_img_num)
    for line in list_from_file(args.img_list):
        progressbar.update()
        img_path = osp.join(args.img_root, line.strip())
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = inference_detector(model, img_path)
        img_name = osp.basename(img_path)
        out_file = osp.join(out_vis_dir, img_name)
        kwargs_dict = {
            'score_thr': args.score_thr,
            'show': False,
            'out_file': out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
|
||||
|
||||
|
||||
# Run the command line entry point only when executed as a script.
if __name__ == '__main__':
    main()
|
|
@ -1,23 +0,0 @@
|
|||
#!/bin/bash
# Run tools/test_imgs.py on a list of images and write results to OUT_DIR.

# DATE and TIME are computed but not used further in this script.
DATE=$(date +%Y-%m-%d)
TIME=$(date +"%H-%M-%S")

if [ "$#" -lt 5 ]
then
    echo "Usage: bash $0 CONFIG CHECKPOINT IMG_ROOT_PATH IMG_LIST OUT_DIR"
    # A usage error must not be reported as success to the caller.
    exit 1
fi

CONFIG_FILE=$1
CHECKPOINT=$2
IMG_ROOT_PATH=$3
IMG_LIST=$4
OUT_DIR=$5

# Quote every expansion so paths containing spaces stay single arguments.
mkdir -p "${OUT_DIR}" &&

python tools/test_imgs.py \
    "${CONFIG_FILE}" "${CHECKPOINT}" "${IMG_ROOT_PATH}" "${IMG_LIST}" \
    --out-dir "${OUT_DIR}"
|
Loading…
Reference in New Issue