mirror of https://github.com/open-mmlab/mmocr.git
Add list_from_file and list_to_file (#226)
* Add list_from_file and list_to_file

Signed-off-by: lizz <lizz@sensetime.com>

* Add tests for list_to_file and list_from_file

* more

* Fix tests
parent 17aa9ecc7f
commit b10b6408ef
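This commit replaces scattered open()/readlines() and mmcv.list_from_file calls with two small helpers in mmocr.utils. A minimal round-trip sketch of the new API (the path is hypothetical; mmocr with this commit is assumed to be installed):

from mmocr.utils import list_from_file, list_to_file

# write one string per line, each terminated with '\n'
list_to_file('/tmp/labels.txt', ['img_1.jpg hello', 'img_2.jpg world'])

# read the lines back, with trailing '\r'/'\n' stripped
assert list_from_file('/tmp/labels.txt') == ['img_1.jpg hello', 'img_2.jpg world']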
@@ -1,15 +1,14 @@
 import copy
 from os import path as osp
 
-import mmcv
 import numpy as np
 import torch
 
-import mmocr.utils as utils
 from mmdet.datasets.builder import DATASETS
 from mmocr.core import compute_f1_score
 from mmocr.datasets.base_dataset import BaseDataset
 from mmocr.datasets.pipelines import sort_vertex8
+from mmocr.utils import is_type_list, list_from_file
 
 
 @DATASETS.register_module()
@@ -52,7 +51,7 @@ class KIEDataset(BaseDataset):
             '': 0,
             **{
                 line.rstrip('\r\n'): ind
-                for ind, line in enumerate(mmcv.list_from_file(dict_file), 1)
+                for ind, line in enumerate(list_from_file(dict_file), 1)
             }
         }
@@ -79,7 +78,7 @@ class KIEDataset(BaseDataset):
                 box_num * (box_num + 1).
         """
 
-        assert utils.is_type_list(annotations, dict)
+        assert is_type_list(annotations, dict)
         assert len(annotations) > 0, 'Please remove data with empty annotation'
         assert 'box' in annotations[0]
         assert 'text' in annotations[0]
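Note on the KIEDataset hunk above: the comprehension maps each line of the dict file to a 1-based index and reserves 0 for the empty string. A runnable sketch of the same pattern, with a hypothetical two-entry dict file:

from mmocr.utils import list_from_file, list_to_file

# hypothetical dict file with one entry per line
list_to_file('/tmp/dict.txt', ['key', 'value'])

node_dict = {
    '': 0,  # index 0 is reserved for the empty string
    **{
        line.rstrip('\r\n'): ind
        for ind, line in enumerate(list_from_file('/tmp/dict.txt'), 1)
    }
}
assert node_dict == {'': 0, 'key': 1, 'value': 2}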
@@ -1,8 +1,7 @@
 import os.path as osp
 
-import mmcv
-
 from mmocr.datasets.builder import LOADERS, build_parser
+from mmocr.utils import list_from_file
 
 
 @LOADERS.register_module()
@@ -60,7 +59,7 @@ class HardDiskLoader(Loader):
     """
 
     def _load(self, ann_file):
-        return mmcv.list_from_file(ann_file)
+        return list_from_file(ann_file)
 
 
 @LOADERS.register_module()
@@ -8,6 +8,7 @@ from mmdet.core import bbox2roi
 from mmdet.models.builder import DETECTORS, build_roi_extractor
 from mmdet.models.detectors import SingleStageDetector
 from mmocr.core import imshow_edge_node
+from mmocr.utils import list_from_file
 
 
 @DETECTORS.register_module()
@@ -126,11 +127,9 @@ class SDMGR(SingleStageDetector):
 
         idx_to_cls = {}
         if self.class_list is not None:
-            with open(self.class_list, 'r') as fr:
-                for line in fr:
-                    line = line.strip().split()
-                    class_idx, class_label = line
-                    idx_to_cls[class_idx] = class_label
+            for line in list_from_file(self.class_list):
+                class_idx, class_label = line.strip().split()
+                idx_to_cls[class_idx] = class_label
 
         # if out_file specified, do not show image in window
        if out_file is not None:
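The rewritten SDMGR loop expects each class-list line to carry an index and a label separated by whitespace; note that the keys stay strings. A runnable sketch with hypothetical file contents:

from mmocr.utils import list_from_file, list_to_file

# hypothetical class list: one '<index> <label>' pair per line
list_to_file('/tmp/class_list.txt', ['0 Ignore', '1 Company'])

idx_to_cls = {}
for line in list_from_file('/tmp/class_list.txt'):
    class_idx, class_label = line.strip().split()
    idx_to_cls[class_idx] = class_label

assert idx_to_cls == {'0': 'Ignore', '1': 'Company'}  # keys remain strings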
@@ -1,4 +1,5 @@
 from mmocr.models.builder import CONVERTORS
+from mmocr.utils import list_from_file
 
 
 @CONVERTORS.register_module()
@@ -27,11 +28,10 @@ class BaseConvertor:
         assert dict_list is None or isinstance(dict_list, list)
         self.idx2char = []
         if dict_file is not None:
-            with open(dict_file, encoding='utf-8') as fr:
-                for line in fr:
-                    line = line.strip()
-                    if line != '':
-                        self.idx2char.append(line)
+            for line in list_from_file(dict_file):
+                line = line.strip()
+                if line != '':
+                    self.idx2char.append(line)
         elif dict_list is not None:
             self.idx2char = dict_list
         else:
@@ -4,6 +4,7 @@ from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list,
                              is_none_or_type, is_type_list, valid_boundary)
 from .collect_env import collect_env
 from .data_convert_util import convert_annotations
+from .fileio import list_from_file, list_to_file
 from .img_util import drop_orientation, is_not_png
 from .lmdb_util import lmdb_converter
 from .logger import get_root_logger
@@ -12,5 +13,6 @@ __all__ = [
     'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
     'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type',
     'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
-    'drop_orientation', 'convert_annotations', 'is_not_png'
+    'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file',
+    'list_from_file'
 ]
@@ -0,0 +1,31 @@
+def list_to_file(filename, lines):
+    """Write a list of strings to a text file.
+
+    Args:
+        filename (str): The output filename. It will be created/overwritten.
+        lines (list(str)): Data to be written.
+    """
+    with open(filename, 'w', encoding='utf-8') as fw:
+        for line in lines:
+            fw.write(f'{line}\n')
+
+
+def list_from_file(filename, encoding='utf-8'):
+    """Load a text file and parse the content as a list of strings. The
+    trailing "\\r" and "\\n" of each line will be removed.
+
+    Note:
+        This will be replaced by mmcv's version after it supports encoding.
+
+    Args:
+        filename (str): Filename.
+        encoding (str): Encoding used to open the file. Default utf-8.
+
+    Returns:
+        list[str]: A list of strings.
+    """
+    item_list = []
+    with open(filename, 'r', encoding=encoding) as f:
+        for line in f:
+            item_list.append(line.rstrip('\n\r'))
+    return item_list
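Two behaviors of the new reader worth noting, shown in a sketch with hypothetical paths: both '\n' and '\r\n' line endings are stripped on read, and the encoding argument enables BOM-aware reads:

from mmocr.utils import list_from_file

# '\r\n'-terminated lines read back the same as '\n'-terminated ones
with open('/tmp/crlf.txt', 'w', encoding='utf-8', newline='') as f:
    f.write('a\r\nb\r\n')
assert list_from_file('/tmp/crlf.txt') == ['a', 'b']

# a leading BOM is consumed when reading with encoding='utf-8-sig'
with open('/tmp/bom.txt', 'w', encoding='utf-8-sig') as f:
    f.write('a\n')
assert list_from_file('/tmp/bom.txt', encoding='utf-8-sig') == ['a']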
@@ -5,11 +5,12 @@ from pathlib import Path
 
 import lmdb
 
+from mmocr.utils import list_from_file
 
-def lmdb_converter(img_list, output, batch_size=1000, coding='utf-8'):
-    # read img_list
-    with open(img_list) as f:
-        lines = f.readlines()
+
+def lmdb_converter(img_list_file, output, batch_size=1000, coding='utf-8'):
+    # read img_list_file
+    lines = list_from_file(img_list_file)
 
     # create lmdb database
     if Path(output).is_dir():
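A call sketch for the renamed first parameter, using only the signature shown in this hunk; the paths are hypothetical, the lmdb package is assumed to be installed, and the list file is assumed to name one sample per line:

from mmocr.utils import lmdb_converter

# hypothetical inputs: a line-list file and a fresh output directory
lmdb_converter('/tmp/img_list.txt', '/tmp/out_lmdb', batch_size=1000)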
@@ -0,0 +1,46 @@
+import tempfile
+
+from mmocr.utils import list_from_file, list_to_file
+
+lists = [
+    [],
+    [' '],
+    ['\t'],
+    ['a'],
+    [1],
+    [1.],
+    ['a', 'b'],
+    ['a', 1, 1.],
+    [1, 1., 'a'],
+    ['啊', '啊啊'],
+    ['選択', 'noël', 'Информацией', 'ÄÆä'],
+]
+
+
+def test_list_to_file():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        for i, lines in enumerate(lists):
+            filename = f'{tmpdirname}/{i}.txt'
+            list_to_file(filename, lines)
+            lines2 = [
+                line.rstrip('\r\n')
+                for line in open(filename, 'r', encoding='utf-8').readlines()
+            ]
+            lines = list(map(str, lines))
+            assert len(lines) == len(lines2)
+            assert all(line1 == line2 for line1, line2 in zip(lines, lines2))
+
+
+def test_list_from_file():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        for encoding in ['utf-8', 'utf-8-sig']:
+            for lineend in ['\n', '\r\n']:
+                for i, lines in enumerate(lists):
+                    filename = f'{tmpdirname}/{i}.txt'
+                    with open(filename, 'w', encoding=encoding) as f:
+                        f.writelines(f'{line}{lineend}' for line in lines)
+                    lines2 = list_from_file(filename, encoding=encoding)
+                    lines = list(map(str, lines))
+                    assert len(lines) == len(lines2)
+                    assert all(line1 == line2
+                               for line1, line2 in zip(lines, lines2))
@@ -1,16 +1,13 @@
 import argparse
-import codecs
 import json
 
 import mmcv
 
-def read_json(fpath):
-    with codecs.open(fpath, 'r', 'utf-8') as f:
-        obj = json.load(f)
-    return obj
+from mmocr.utils import list_to_file
 
+
 def parse_coco_json(in_path):
-    json_obj = read_json(in_path)
+    json_obj = mmcv.load(in_path)
     image_infos = json_obj['images']
     annotations = json_obj['annotations']
     imgid2imgname = {}
@@ -35,18 +32,17 @@ def parse_coco_json(in_path):
 
 
 def gen_line_dict_file(out_path, imgid2imgname, imgid2anno):
-    # import pdb; pdb.set_trace()
-    with codecs.open(out_path, 'w', 'utf-8') as fw:
-        for key, value in imgid2imgname.items():
-            if key in imgid2anno:
-                anno = imgid2anno[key]
-                line_dict = {}
-                line_dict['file_name'] = value['file_name']
-                line_dict['height'] = value['height']
-                line_dict['width'] = value['width']
-                line_dict['annotations'] = anno
-                line_dict_str = json.dumps(line_dict)
-                fw.write(line_dict_str + '\n')
+    lines = []
+    for key, value in imgid2imgname.items():
+        if key in imgid2anno:
+            anno = imgid2anno[key]
+            line_dict = {}
+            line_dict['file_name'] = value['file_name']
+            line_dict['height'] = value['height']
+            line_dict['width'] = value['width']
+            line_dict['annotations'] = anno
+            lines.append(json.dumps(line_dict))
+    list_to_file(out_path, lines)
 
 
 def parse_args():
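mmcv.load dispatches to a JSON/YAML/pickle handler based on the file suffix, which is why the hand-rolled read_json above could be dropped; a one-line sketch with a hypothetical path:

import mmcv

# picks the JSON handler from the '.json' suffix
json_obj = mmcv.load('/tmp/instances.json')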
@@ -8,7 +8,8 @@ import mmcv
 import numpy as np
 from shapely.geometry import Polygon
 
-from mmocr.utils import convert_annotations, drop_orientation, is_not_png
+from mmocr.utils import (convert_annotations, drop_orientation, is_not_png,
+                         list_from_file)
 
 
 def collect_files(img_dir, gt_dir, split):
@@ -84,11 +85,8 @@ def collect_annotations(files, split, nproc=1):
 
 
 def load_txt_info(gt_file, img_info):
-    with open(gt_file) as f:
-        gt_list = f.readlines()
-
     anno_info = []
-    for line in gt_list:
+    for line in list_from_file(gt_file):
         # each line has one polygon (n vertices), and one text.
         # e.g., 695,885,866,888,867,1146,696,1143,####Latin 9
         line = line.strip()
@@ -7,7 +7,8 @@ import mmcv
 import numpy as np
 from shapely.geometry import Polygon
 
-from mmocr.utils import convert_annotations, drop_orientation, is_not_png
+from mmocr.utils import (convert_annotations, drop_orientation, is_not_png,
+                         list_from_file)
 
 
 def collect_files(img_dir, gt_dir):
@@ -96,11 +97,9 @@ def load_img_info(files, dataset):
     assert img.shape[0:2] == img_color.shape[0:2]
 
     if dataset == 'icdar2017':
-        with open(gt_file) as f:
-            gt_list = f.readlines()
+        gt_list = list_from_file(gt_file)
     elif dataset == 'icdar2015':
-        with open(gt_file, mode='r', encoding='utf-8-sig') as f:
-            gt_list = f.readlines()
+        gt_list = list_from_file(gt_file, encoding='utf-8-sig')
     else:
         raise NotImplementedError(f'Not support {dataset}')
 
@@ -4,64 +4,62 @@ import os.path as osp
 
 import cv2
 
+from mmocr.utils import list_from_file, list_to_file
+
 
 def parse_old_label(data_root, in_path, img_size=False):
     imgid2imgname = {}
     imgid2anno = {}
     idx = 0
-    with open(in_path, 'r') as fr:
-        for line in fr:
-            line = line.strip().split()
-            img_full_path = osp.join(data_root, line[0])
-            if not osp.exists(img_full_path):
-                continue
-            ann_file = osp.join(data_root, line[1])
-            if not osp.exists(ann_file):
-                continue
+    for line in list_from_file(in_path):
+        line = line.strip().split()
+        img_full_path = osp.join(data_root, line[0])
+        if not osp.exists(img_full_path):
+            continue
+        ann_file = osp.join(data_root, line[1])
+        if not osp.exists(ann_file):
+            continue
 
-            img_info = {}
-            img_info['file_name'] = line[0]
-            if img_size:
-                img = cv2.imread(img_full_path)
-                h, w = img.shape[:2]
-                img_info['height'] = h
-                img_info['width'] = w
-            imgid2imgname[idx] = img_info
+        img_info = {}
+        img_info['file_name'] = line[0]
+        if img_size:
+            img = cv2.imread(img_full_path)
+            h, w = img.shape[:2]
+            img_info['height'] = h
+            img_info['width'] = w
+        imgid2imgname[idx] = img_info
 
-            imgid2anno[idx] = []
-            char_annos = []
-            with open(ann_file, 'r') as fr:
-                t = 0
-                for line in fr:
-                    line = line.strip()
-                    if t == 0:
-                        img_info['text'] = line
-                    else:
-                        char_box = [float(x) for x in line.split()]
-                        char_text = img_info['text'][t - 1]
-                        char_ann = dict(char_box=char_box, char_text=char_text)
-                        char_annos.append(char_ann)
-                    t += 1
-            imgid2anno[idx] = char_annos
-            idx += 1
+        imgid2anno[idx] = []
+        char_annos = []
+        for t, ann_line in enumerate(list_from_file(ann_file)):
+            ann_line = ann_line.strip()
+            if t == 0:
+                img_info['text'] = ann_line
+            else:
+                char_box = [float(x) for x in ann_line.split()]
+                char_text = img_info['text'][t - 1]
+                char_ann = dict(char_box=char_box, char_text=char_text)
+                char_annos.append(char_ann)
+        imgid2anno[idx] = char_annos
+        idx += 1
 
     return imgid2imgname, imgid2anno
 
 
 def gen_line_dict_file(out_path, imgid2imgname, imgid2anno, img_size=False):
-    with open(out_path, 'w', encoding='utf-8') as fw:
-        for key, value in imgid2imgname.items():
-            if key in imgid2anno:
-                anno = imgid2anno[key]
-                line_dict = {}
-                line_dict['file_name'] = value['file_name']
-                line_dict['text'] = value['text']
-                if img_size:
-                    line_dict['height'] = value['height']
-                    line_dict['width'] = value['width']
-                line_dict['annotations'] = anno
-                line_dict_str = json.dumps(line_dict)
-                fw.write(line_dict_str + '\n')
+    lines = []
+    for key, value in imgid2imgname.items():
+        if key in imgid2anno:
+            anno = imgid2anno[key]
+            line_dict = {}
+            line_dict['file_name'] = value['file_name']
+            line_dict['text'] = value['text']
+            if img_size:
+                line_dict['height'] = value['height']
+                line_dict['width'] = value['width']
+            line_dict['annotations'] = anno
+            lines.append(json.dumps(line_dict))
+    list_to_file(out_path, lines)
 
 
 def parse_args():
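The parse_old_label rewrite also swaps the manually maintained counter t for enumerate, removing the easy-to-miss t += 1 at the loop tail. A runnable sketch of the pattern with hypothetical annotation lines (first line is the text, then one box per character):

lines = ['HI', '10 10 30 30', '35 10 55 30']

char_annos = []
for t, line in enumerate(lines):
    if t == 0:
        text = line  # the transcription comes first
    else:
        char_box = [float(x) for x in line.split()]
        # box t - 1 pairs with character t - 1 of the transcription
        char_annos.append(dict(char_box=char_box, char_text=text[t - 1]))

assert [a['char_text'] for a in char_annos] == ['H', 'I']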
@@ -5,6 +5,8 @@ import xml.etree.ElementTree as ET
 
 import cv2
 
+from mmocr.utils.fileio import list_to_file
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -43,35 +45,35 @@ def main():
     root = tree.getroot()
 
     index = 1
-    with open(dst_label_file, 'w', encoding='utf-8') as fw:
-        total_img_num = len(root)
-        i = 1
-        for image_node in root.findall('image'):
-            image_name = image_node.find('imageName').text
-            print(f'[{i}/{total_img_num}] Process image: {image_name}')
-            i += 1
-            lexicon = image_node.find('lex').text.lower()
-            lexicon_list = lexicon.split(',')
-            lex_size = len(lexicon_list)
-            src_img = cv2.imread(osp.join(src_image_root, image_name))
-            for rectangle in image_node.find('taggedRectangles'):
-                x = int(rectangle.get('x'))
-                y = int(rectangle.get('y'))
-                w = int(rectangle.get('width'))
-                h = int(rectangle.get('height'))
-                rb, re = max(0, y), max(0, y + h)
-                cb, ce = max(0, x), max(0, x + w)
-                dst_img = src_img[rb:re, cb:ce]
-                text_label = rectangle.find('tag').text.lower()
-                if args.resize:
-                    dst_img = cv2.resize(dst_img, (args.width, args.height))
-                dst_img_name = f'img_{index:04}' + '.jpg'
-                index += 1
-                dst_img_path = osp.join(dst_image_root, dst_img_name)
-                cv2.imwrite(dst_img_path, dst_img)
-                fw.write(f'{osp.basename(dst_image_root)}/{dst_img_name} '
-                         f'{text_label} {lex_size} {lexicon}\n')
-
+    lines = []
+    total_img_num = len(root)
+    i = 1
+    for image_node in root.findall('image'):
+        image_name = image_node.find('imageName').text
+        print(f'[{i}/{total_img_num}] Process image: {image_name}')
+        i += 1
+        lexicon = image_node.find('lex').text.lower()
+        lexicon_list = lexicon.split(',')
+        lex_size = len(lexicon_list)
+        src_img = cv2.imread(osp.join(src_image_root, image_name))
+        for rectangle in image_node.find('taggedRectangles'):
+            x = int(rectangle.get('x'))
+            y = int(rectangle.get('y'))
+            w = int(rectangle.get('width'))
+            h = int(rectangle.get('height'))
+            rb, re = max(0, y), max(0, y + h)
+            cb, ce = max(0, x), max(0, x + w)
+            dst_img = src_img[rb:re, cb:ce]
+            text_label = rectangle.find('tag').text.lower()
+            if args.resize:
+                dst_img = cv2.resize(dst_img, (args.width, args.height))
+            dst_img_name = f'img_{index:04}' + '.jpg'
+            index += 1
+            dst_img_path = osp.join(dst_image_root, dst_img_name)
+            cv2.imwrite(dst_img_path, dst_img)
+            lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} '
+                         f'{text_label} {lex_size} {lexicon}')
+    list_to_file(dst_label_file, lines)
     print(f'Finish to generate svt testset, '
           f'with label file {dst_label_file}')
@@ -3,6 +3,7 @@ import os.path as osp
 import shutil
 import time
 from argparse import ArgumentParser
+from itertools import compress
 
 import mmcv
 import torch
@@ -13,7 +14,7 @@ from mmocr.apis import model_inference
 from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
 from mmocr.datasets import build_dataset  # noqa: F401
 from mmocr.models import build_detector  # noqa: F401
-from mmocr.utils import get_root_logger
+from mmocr.utils import get_root_logger, list_from_file, list_to_file
 
 
 def save_results(img_paths, pred_labels, gt_labels, res_dir):
@@ -26,21 +27,15 @@ def save_results(img_paths, pred_labels, gt_labels, res_dir):
         res_dir (str)
     """
     assert len(img_paths) == len(pred_labels) == len(gt_labels)
-    res_file = osp.join(res_dir, 'results.txt')
-    correct_file = osp.join(res_dir, 'correct.txt')
-    wrong_file = osp.join(res_dir, 'wrong.txt')
-    with open(res_file, 'w') as fw, \
-            open(correct_file, 'w') as fw_correct, \
-            open(wrong_file, 'w') as fw_wrong:
-        for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
-                                                  gt_labels):
-            fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
-            if pred_label == gt_label:
-                fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label +
-                                 '\n')
-            else:
-                fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
-                               '\n')
+    corrects = [pred == gt for pred, gt in zip(pred_labels, gt_labels)]
+    wrongs = [not c for c in corrects]
+    lines = [
+        f'{img} {pred} {gt}'
+        for img, pred, gt in zip(img_paths, pred_labels, gt_labels)
+    ]
+    list_to_file(osp.join(res_dir, 'results.txt'), lines)
+    list_to_file(osp.join(res_dir, 'correct.txt'), compress(lines, corrects))
+    list_to_file(osp.join(res_dir, 'wrong.txt'), compress(lines, wrongs))
 
 
 def main():
@@ -80,39 +75,38 @@ def main():
     total_img_num = sum([1 for _ in open(args.img_list)])
     progressbar = ProgressBar(task_num=total_img_num)
     num_gt_label = 0
-    with open(args.img_list, 'r') as fr:
-        for line in fr:
-            progressbar.update()
-            item_list = line.strip().split()
-            img_file = item_list[0]
-            gt_label = ''
-            if len(item_list) >= 2:
-                gt_label = item_list[1]
-                num_gt_label += 1
-            img_path = osp.join(args.img_root_path, img_file)
-            if not osp.exists(img_path):
-                raise FileNotFoundError(img_path)
-            # Test a single image
-            result = model_inference(model, img_path)
-            pred_label = result['text']
+    for line in list_from_file(args.img_list):
+        progressbar.update()
+        item_list = line.strip().split()
+        img_file = item_list[0]
+        gt_label = ''
+        if len(item_list) >= 2:
+            gt_label = item_list[1]
+            num_gt_label += 1
+        img_path = osp.join(args.img_root_path, img_file)
+        if not osp.exists(img_path):
+            raise FileNotFoundError(img_path)
+        # Test a single image
+        result = model_inference(model, img_path)
+        pred_label = result['text']
 
-            out_img_name = '_'.join(img_file.split('/'))
-            out_file = osp.join(out_vis_dir, out_img_name)
-            kwargs_dict = {
-                'gt_label': gt_label,
-                'show': args.show,
-                'out_file': '' if args.show else out_file
-            }
-            model.show_result(img_path, result, **kwargs_dict)
-            if gt_label != '':
-                if gt_label == pred_label:
-                    dst_file = osp.join(correct_vis_dir, out_img_name)
-                else:
-                    dst_file = osp.join(wrong_vis_dir, out_img_name)
-                shutil.copy(out_file, dst_file)
-            img_paths.append(img_path)
-            gt_labels.append(gt_label)
-            pred_labels.append(pred_label)
+        out_img_name = '_'.join(img_file.split('/'))
+        out_file = osp.join(out_vis_dir, out_img_name)
+        kwargs_dict = {
+            'gt_label': gt_label,
+            'show': args.show,
+            'out_file': '' if args.show else out_file
+        }
+        model.show_result(img_path, result, **kwargs_dict)
+        if gt_label != '':
+            if gt_label == pred_label:
+                dst_file = osp.join(correct_vis_dir, out_img_name)
+            else:
+                dst_file = osp.join(wrong_vis_dir, out_img_name)
+            shutil.copy(out_file, dst_file)
+        img_paths.append(img_path)
+        gt_labels.append(gt_label)
+        pred_labels.append(pred_label)
 
     # Save results
     save_results(img_paths, pred_labels, gt_labels, args.out_dir)
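itertools.compress filters one iterable by a parallel boolean mask, which is what lets save_results above emit the correct/wrong splits without a second pass over the data; a minimal stdlib-only sketch:

from itertools import compress

lines = ['img1 a a', 'img2 b c', 'img3 d d']
corrects = [True, False, True]

# keep entries whose mask value is truthy
assert list(compress(lines, corrects)) == ['img1 a a', 'img3 d d']
assert list(compress(lines, (not c for c in corrects))) == ['img2 b c']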
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-import codecs
 import os.path as osp
 from argparse import ArgumentParser
 
@@ -11,6 +10,7 @@ from mmcv.utils import ProgressBar
 from mmdet.apis import inference_detector, init_detector
 from mmocr.core.evaluation.utils import filter_result
 from mmocr.models import build_detector  # noqa: F401
+from mmocr.utils import list_from_file, list_to_file
 
 
 def gen_target_path(target_root_path, src_name, suffix):
@@ -25,9 +25,9 @@ def gen_target_path(target_root_path, src_name, suffix):
     assert isinstance(src_name, str)
     assert isinstance(suffix, str)
 
-    dir_name, file_name = osp.split(src_name)
-    name, file_suffix = osp.splitext(file_name)
-    return target_root_path + '/' + name + suffix
+    file_name = osp.split(src_name)[-1]
+    name = osp.splitext(file_name)[0]
+    return osp.join(target_root_path, name + suffix)
 
 
 def save_2darray(mat, file_name):
@@ -37,10 +37,8 @@ def save_2darray(mat, file_name):
         mat (ndarray): 2d-array of shape (n, m).
         file_name (str): The output file name.
     """
-    with codecs.open(file_name, 'w', 'utf-8') as fw:
-        for row in mat:
-            row_str = ','.join([str(x) for x in row])
-            fw.write(row_str + '\n')
+    lines = [','.join([str(x) for x in row]) for row in mat]
+    list_to_file(file_name, lines)
 
 
 def save_bboxes_quadrangles(bboxes_with_scores,
@@ -144,22 +142,21 @@ def main():
 
     total_img_num = sum([1 for _ in open(args.img_list)])
     progressbar = ProgressBar(task_num=total_img_num)
-    with codecs.open(args.img_list, 'r', 'utf-8') as fr:
-        for line in fr:
-            progressbar.update()
-            img_path = args.img_root + '/' + line.strip()
-            if not osp.exists(img_path):
-                raise FileNotFoundError(img_path)
-            # Test a single image
-            result = inference_detector(model, img_path)
-            img_name = osp.basename(img_path)
-            out_file = osp.join(out_vis_dir, img_name)
-            kwargs_dict = {
-                'score_thr': args.score_thr,
-                'show': False,
-                'out_file': out_file
-            }
-            model.show_result(img_path, result, **kwargs_dict)
+    for line in list_from_file(args.img_list):
+        progressbar.update()
+        img_path = osp.join(args.img_root, line.strip())
+        if not osp.exists(img_path):
+            raise FileNotFoundError(img_path)
+        # Test a single image
+        result = inference_detector(model, img_path)
+        img_name = osp.basename(img_path)
+        out_file = osp.join(out_vis_dir, img_name)
+        kwargs_dict = {
+            'score_thr': args.score_thr,
+            'show': False,
+            'out_file': out_file
+        }
+        model.show_result(img_path, result, **kwargs_dict)
 
     print(f'\nInference done, and results saved in {args.out_dir}\n')