Add list_from_file and list_to_file (#226)

* Add list_from_file and list_to_file

Signed-off-by: lizz <lizz@sensetime.com>

* Add test list_to_file and list_from_file

* more

* Fix tests
pull/234/head
lizz 2021-05-24 14:01:42 +08:00 committed by GitHub
parent 17aa9ecc7f
commit b10b6408ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 258 additions and 197 deletions

View File

@ -1,15 +1,14 @@
import copy
from os import path as osp
import mmcv
import numpy as np
import torch
import mmocr.utils as utils
from mmdet.datasets.builder import DATASETS
from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines import sort_vertex8
from mmocr.utils import is_type_list, list_from_file
@DATASETS.register_module()
@ -52,7 +51,7 @@ class KIEDataset(BaseDataset):
'': 0,
**{
line.rstrip('\r\n'): ind
for ind, line in enumerate(mmcv.list_from_file(dict_file), 1)
for ind, line in enumerate(list_from_file(dict_file), 1)
}
}
@ -79,7 +78,7 @@ class KIEDataset(BaseDataset):
box_num * (box_num + 1).
"""
assert utils.is_type_list(annotations, dict)
assert is_type_list(annotations, dict)
assert len(annotations) > 0, 'Please remove data with empty annotation'
assert 'box' in annotations[0]
assert 'text' in annotations[0]

View File

@ -1,8 +1,7 @@
import os.path as osp
import mmcv
from mmocr.datasets.builder import LOADERS, build_parser
from mmocr.utils import list_from_file
@LOADERS.register_module()
@ -60,7 +59,7 @@ class HardDiskLoader(Loader):
"""
def _load(self, ann_file):
return mmcv.list_from_file(ann_file)
return list_from_file(ann_file)
@LOADERS.register_module()

View File

@ -8,6 +8,7 @@ from mmdet.core import bbox2roi
from mmdet.models.builder import DETECTORS, build_roi_extractor
from mmdet.models.detectors import SingleStageDetector
from mmocr.core import imshow_edge_node
from mmocr.utils import list_from_file
@DETECTORS.register_module()
@ -126,11 +127,9 @@ class SDMGR(SingleStageDetector):
idx_to_cls = {}
if self.class_list is not None:
with open(self.class_list, 'r') as fr:
for line in fr:
line = line.strip().split()
class_idx, class_label = line
idx_to_cls[class_idx] = class_label
for line in list_from_file(self.class_list):
class_idx, class_label = line.strip().split()
idx_to_cls[class_idx] = class_label
# if out_file specified, do not show image in window
if out_file is not None:

View File

@ -1,4 +1,5 @@
from mmocr.models.builder import CONVERTORS
from mmocr.utils import list_from_file
@CONVERTORS.register_module()
@ -27,11 +28,10 @@ class BaseConvertor:
assert dict_list is None or isinstance(dict_list, list)
self.idx2char = []
if dict_file is not None:
with open(dict_file, encoding='utf-8') as fr:
for line in fr:
line = line.strip()
if line != '':
self.idx2char.append(line)
for line in list_from_file(dict_file):
line = line.strip()
if line != '':
self.idx2char.append(line)
elif dict_list is not None:
self.idx2char = dict_list
else:

View File

@ -4,6 +4,7 @@ from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list,
is_none_or_type, is_type_list, valid_boundary)
from .collect_env import collect_env
from .data_convert_util import convert_annotations
from .fileio import list_from_file, list_to_file
from .img_util import drop_orientation, is_not_png
from .lmdb_util import lmdb_converter
from .logger import get_root_logger
@ -12,5 +13,6 @@ __all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type',
'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
'drop_orientation', 'convert_annotations', 'is_not_png'
'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file',
'list_from_file'
]

View File

@ -0,0 +1,31 @@
def list_to_file(filename, lines):
"""Write a list of strings to a text file.
Args:
filename (str): The output filename. It will be created/overwritten.
lines (list(str)): Data to be written.
"""
with open(filename, 'w', encoding='utf-8') as fw:
for line in lines:
fw.write(f'{line}\n')
def list_from_file(filename, encoding='utf-8'):
"""Load a text file and parse the content as a list of strings. The
trailing "\\r" and "\\n" of each line will be removed.
Note:
This will be replaced by mmcv's version after it supports encoding.
Args:
filename (str): Filename.
encoding (str): Encoding used to open the file. Default utf-8.
Returns:
list[str]: A list of strings.
"""
item_list = []
with open(filename, 'r', encoding=encoding) as f:
for line in f:
item_list.append(line.rstrip('\n\r'))
return item_list

View File

@ -5,11 +5,12 @@ from pathlib import Path
import lmdb
from mmocr.utils import list_from_file
def lmdb_converter(img_list, output, batch_size=1000, coding='utf-8'):
# read img_list
with open(img_list) as f:
lines = f.readlines()
def lmdb_converter(img_list_file, output, batch_size=1000, coding='utf-8'):
# read img_list_file
lines = list_from_file(img_list_file)
# create lmdb database
if Path(output).is_dir():

View File

@ -0,0 +1,46 @@
import tempfile
from mmocr.utils import list_from_file, list_to_file
lists = [
[],
[' '],
['\t'],
['a'],
[1],
[1.],
['a', 'b'],
['a', 1, 1.],
[1, 1., 'a'],
['', '啊啊'],
['選択', 'noël', 'Информацией', 'ÄÆä'],
]
def test_list_to_file():
with tempfile.TemporaryDirectory() as tmpdirname:
for i, lines in enumerate(lists):
filename = f'{tmpdirname}/{i}.txt'
list_to_file(filename, lines)
lines2 = [
line.rstrip('\r\n')
for line in open(filename, 'r', encoding='utf-8').readlines()
]
lines = list(map(str, lines))
assert len(lines) == len(lines2)
assert all(line1 == line2 for line1, line2 in zip(lines, lines2))
def test_list_from_file():
with tempfile.TemporaryDirectory() as tmpdirname:
for encoding in ['utf-8', 'utf-8-sig']:
for lineend in ['\n', '\r\n']:
for i, lines in enumerate(lists):
filename = f'{tmpdirname}/{i}.txt'
with open(filename, 'w', encoding=encoding) as f:
f.writelines(f'{line}{lineend}' for line in lines)
lines2 = list_from_file(filename, encoding=encoding)
lines = list(map(str, lines))
assert len(lines) == len(lines2)
assert all(line1 == line2
for line1, line2 in zip(lines, lines2))

View File

@ -1,16 +1,13 @@
import argparse
import codecs
import json
import mmcv
def read_json(fpath):
with codecs.open(fpath, 'r', 'utf-8') as f:
obj = json.load(f)
return obj
from mmocr.utils import list_to_file
def parse_coco_json(in_path):
json_obj = read_json(in_path)
json_obj = mmcv.load(in_path)
image_infos = json_obj['images']
annotations = json_obj['annotations']
imgid2imgname = {}
@ -35,18 +32,17 @@ def parse_coco_json(in_path):
def gen_line_dict_file(out_path, imgid2imgname, imgid2anno):
# import pdb; pdb.set_trace()
with codecs.open(out_path, 'w', 'utf-8') as fw:
for key, value in imgid2imgname.items():
if key in imgid2anno:
anno = imgid2anno[key]
line_dict = {}
line_dict['file_name'] = value['file_name']
line_dict['height'] = value['height']
line_dict['width'] = value['width']
line_dict['annotations'] = anno
line_dict_str = json.dumps(line_dict)
fw.write(line_dict_str + '\n')
lines = []
for key, value in imgid2imgname.items():
if key in imgid2anno:
anno = imgid2anno[key]
line_dict = {}
line_dict['file_name'] = value['file_name']
line_dict['height'] = value['height']
line_dict['width'] = value['width']
line_dict['annotations'] = anno
lines.append(json.dumps(line_dict))
list_to_file(out_path, lines)
def parse_args():

View File

@ -8,7 +8,8 @@ import mmcv
import numpy as np
from shapely.geometry import Polygon
from mmocr.utils import convert_annotations, drop_orientation, is_not_png
from mmocr.utils import (convert_annotations, drop_orientation, is_not_png,
list_from_file)
def collect_files(img_dir, gt_dir, split):
@ -84,11 +85,8 @@ def collect_annotations(files, split, nproc=1):
def load_txt_info(gt_file, img_info):
with open(gt_file) as f:
gt_list = f.readlines()
anno_info = []
for line in gt_list:
for line in list_from_file(gt_file):
# each line has one ploygen (n vetices), and one text.
# e.g., 695,885,866,888,867,1146,696,1143,####Latin 9
line = line.strip()

View File

@ -7,7 +7,8 @@ import mmcv
import numpy as np
from shapely.geometry import Polygon
from mmocr.utils import convert_annotations, drop_orientation, is_not_png
from mmocr.utils import (convert_annotations, drop_orientation, is_not_png,
list_from_file)
def collect_files(img_dir, gt_dir):
@ -96,11 +97,9 @@ def load_img_info(files, dataset):
assert img.shape[0:2] == img_color.shape[0:2]
if dataset == 'icdar2017':
with open(gt_file) as f:
gt_list = f.readlines()
gt_list = list_from_file(gt_file)
elif dataset == 'icdar2015':
with open(gt_file, mode='r', encoding='utf-8-sig') as f:
gt_list = f.readlines()
gt_list = list_from_file(gt_file, encoding='utf-8-sig')
else:
raise NotImplementedError(f'Not support {dataset}')

View File

@ -4,64 +4,62 @@ import os.path as osp
import cv2
from mmocr.utils import list_from_file, list_to_file
def parse_old_label(data_root, in_path, img_size=False):
imgid2imgname = {}
imgid2anno = {}
idx = 0
with open(in_path, 'r') as fr:
for line in fr:
line = line.strip().split()
img_full_path = osp.join(data_root, line[0])
if not osp.exists(img_full_path):
continue
ann_file = osp.join(data_root, line[1])
if not osp.exists(ann_file):
continue
for line in list_from_file(in_path):
line = line.strip().split()
img_full_path = osp.join(data_root, line[0])
if not osp.exists(img_full_path):
continue
ann_file = osp.join(data_root, line[1])
if not osp.exists(ann_file):
continue
img_info = {}
img_info['file_name'] = line[0]
if img_size:
img = cv2.imread(img_full_path)
h, w = img.shape[:2]
img_info['height'] = h
img_info['width'] = w
imgid2imgname[idx] = img_info
img_info = {}
img_info['file_name'] = line[0]
if img_size:
img = cv2.imread(img_full_path)
h, w = img.shape[:2]
img_info['height'] = h
img_info['width'] = w
imgid2imgname[idx] = img_info
imgid2anno[idx] = []
char_annos = []
with open(ann_file, 'r') as fr:
t = 0
for line in fr:
line = line.strip()
if t == 0:
img_info['text'] = line
else:
char_box = [float(x) for x in line.split()]
char_text = img_info['text'][t - 1]
char_ann = dict(char_box=char_box, char_text=char_text)
char_annos.append(char_ann)
t += 1
imgid2anno[idx] = char_annos
idx += 1
imgid2anno[idx] = []
char_annos = []
for t, ann_line in enumerate(list_from_file(ann_file)):
ann_line = ann_line.strip()
if t == 0:
img_info['text'] = ann_line
else:
char_box = [float(x) for x in ann_line.split()]
char_text = img_info['text'][t - 1]
char_ann = dict(char_box=char_box, char_text=char_text)
char_annos.append(char_ann)
imgid2anno[idx] = char_annos
idx += 1
return imgid2imgname, imgid2anno
def gen_line_dict_file(out_path, imgid2imgname, imgid2anno, img_size=False):
with open(out_path, 'w', encoding='utf-8') as fw:
for key, value in imgid2imgname.items():
if key in imgid2anno:
anno = imgid2anno[key]
line_dict = {}
line_dict['file_name'] = value['file_name']
line_dict['text'] = value['text']
if img_size:
line_dict['height'] = value['height']
line_dict['width'] = value['width']
line_dict['annotations'] = anno
line_dict_str = json.dumps(line_dict)
fw.write(line_dict_str + '\n')
lines = []
for key, value in imgid2imgname.items():
if key in imgid2anno:
anno = imgid2anno[key]
line_dict = {}
line_dict['file_name'] = value['file_name']
line_dict['text'] = value['text']
if img_size:
line_dict['height'] = value['height']
line_dict['width'] = value['width']
line_dict['annotations'] = anno
lines.append(json.dumps(line_dict))
list_to_file(out_path, lines)
def parse_args():

View File

@ -5,6 +5,8 @@ import xml.etree.ElementTree as ET
import cv2
from mmocr.utils.fileio import list_to_file
def parse_args():
parser = argparse.ArgumentParser(
@ -43,35 +45,35 @@ def main():
root = tree.getroot()
index = 1
with open(dst_label_file, 'w', encoding='utf-8') as fw:
total_img_num = len(root)
i = 1
for image_node in root.findall('image'):
image_name = image_node.find('imageName').text
print(f'[{i}/{total_img_num}] Process image: {image_name}')
i += 1
lexicon = image_node.find('lex').text.lower()
lexicon_list = lexicon.split(',')
lex_size = len(lexicon_list)
src_img = cv2.imread(osp.join(src_image_root, image_name))
for rectangle in image_node.find('taggedRectangles'):
x = int(rectangle.get('x'))
y = int(rectangle.get('y'))
w = int(rectangle.get('width'))
h = int(rectangle.get('height'))
rb, re = max(0, y), max(0, y + h)
cb, ce = max(0, x), max(0, x + w)
dst_img = src_img[rb:re, cb:ce]
text_label = rectangle.find('tag').text.lower()
if args.resize:
dst_img = cv2.resize(dst_img, (args.width, args.height))
dst_img_name = f'img_{index:04}' + '.jpg'
index += 1
dst_img_path = osp.join(dst_image_root, dst_img_name)
cv2.imwrite(dst_img_path, dst_img)
fw.write(f'{osp.basename(dst_image_root)}/{dst_img_name} '
f'{text_label} {lex_size} {lexicon}\n')
lines = []
total_img_num = len(root)
i = 1
for image_node in root.findall('image'):
image_name = image_node.find('imageName').text
print(f'[{i}/{total_img_num}] Process image: {image_name}')
i += 1
lexicon = image_node.find('lex').text.lower()
lexicon_list = lexicon.split(',')
lex_size = len(lexicon_list)
src_img = cv2.imread(osp.join(src_image_root, image_name))
for rectangle in image_node.find('taggedRectangles'):
x = int(rectangle.get('x'))
y = int(rectangle.get('y'))
w = int(rectangle.get('width'))
h = int(rectangle.get('height'))
rb, re = max(0, y), max(0, y + h)
cb, ce = max(0, x), max(0, x + w)
dst_img = src_img[rb:re, cb:ce]
text_label = rectangle.find('tag').text.lower()
if args.resize:
dst_img = cv2.resize(dst_img, (args.width, args.height))
dst_img_name = f'img_{index:04}' + '.jpg'
index += 1
dst_img_path = osp.join(dst_image_root, dst_img_name)
cv2.imwrite(dst_img_path, dst_img)
lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} '
f'{text_label} {lex_size} {lexicon}')
list_to_file(dst_label_file, lines)
print(f'Finish to generate svt testset, '
f'with label file {dst_label_file}')

View File

@ -3,6 +3,7 @@ import os.path as osp
import shutil
import time
from argparse import ArgumentParser
from itertools import compress
import mmcv
import torch
@ -13,7 +14,7 @@ from mmocr.apis import model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.utils import get_root_logger
from mmocr.utils import get_root_logger, list_from_file, list_to_file
def save_results(img_paths, pred_labels, gt_labels, res_dir):
@ -26,21 +27,15 @@ def save_results(img_paths, pred_labels, gt_labels, res_dir):
res_dir (str)
"""
assert len(img_paths) == len(pred_labels) == len(gt_labels)
res_file = osp.join(res_dir, 'results.txt')
correct_file = osp.join(res_dir, 'correct.txt')
wrong_file = osp.join(res_dir, 'wrong.txt')
with open(res_file, 'w') as fw, \
open(correct_file, 'w') as fw_correct, \
open(wrong_file, 'w') as fw_wrong:
for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
gt_labels):
fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
if pred_label == gt_label:
fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label +
'\n')
else:
fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
'\n')
corrects = [pred == gt for pred, gt in zip(pred_labels, gt_labels)]
wrongs = [not c for c in corrects]
lines = [
f'{img} {pred} {gt}'
for img, pred, gt in zip(img_paths, pred_labels, gt_labels)
]
list_to_file(osp.join(res_dir, 'results.txt'), lines)
list_to_file(osp.join(res_dir, 'correct.txt'), compress(lines, corrects))
list_to_file(osp.join(res_dir, 'wrong.txt'), compress(lines, wrongs))
def main():
@ -80,39 +75,38 @@ def main():
total_img_num = sum([1 for _ in open(args.img_list)])
progressbar = ProgressBar(task_num=total_img_num)
num_gt_label = 0
with open(args.img_list, 'r') as fr:
for line in fr:
progressbar.update()
item_list = line.strip().split()
img_file = item_list[0]
gt_label = ''
if len(item_list) >= 2:
gt_label = item_list[1]
num_gt_label += 1
img_path = osp.join(args.img_root_path, img_file)
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = model_inference(model, img_path)
pred_label = result['text']
for line in list_from_file(args.img_list):
progressbar.update()
item_list = line.strip().split()
img_file = item_list[0]
gt_label = ''
if len(item_list) >= 2:
gt_label = item_list[1]
num_gt_label += 1
img_path = osp.join(args.img_root_path, img_file)
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = model_inference(model, img_path)
pred_label = result['text']
out_img_name = '_'.join(img_file.split('/'))
out_file = osp.join(out_vis_dir, out_img_name)
kwargs_dict = {
'gt_label': gt_label,
'show': args.show,
'out_file': '' if args.show else out_file
}
model.show_result(img_path, result, **kwargs_dict)
if gt_label != '':
if gt_label == pred_label:
dst_file = osp.join(correct_vis_dir, out_img_name)
else:
dst_file = osp.join(wrong_vis_dir, out_img_name)
shutil.copy(out_file, dst_file)
img_paths.append(img_path)
gt_labels.append(gt_label)
pred_labels.append(pred_label)
out_img_name = '_'.join(img_file.split('/'))
out_file = osp.join(out_vis_dir, out_img_name)
kwargs_dict = {
'gt_label': gt_label,
'show': args.show,
'out_file': '' if args.show else out_file
}
model.show_result(img_path, result, **kwargs_dict)
if gt_label != '':
if gt_label == pred_label:
dst_file = osp.join(correct_vis_dir, out_img_name)
else:
dst_file = osp.join(wrong_vis_dir, out_img_name)
shutil.copy(out_file, dst_file)
img_paths.append(img_path)
gt_labels.append(gt_label)
pred_labels.append(pred_label)
# Save results
save_results(img_paths, pred_labels, gt_labels, args.out_dir)

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python
import codecs
import os.path as osp
from argparse import ArgumentParser
@ -11,6 +10,7 @@ from mmcv.utils import ProgressBar
from mmdet.apis import inference_detector, init_detector
from mmocr.core.evaluation.utils import filter_result
from mmocr.models import build_detector # noqa: F401
from mmocr.utils import list_from_file, list_to_file
def gen_target_path(target_root_path, src_name, suffix):
@ -25,9 +25,9 @@ def gen_target_path(target_root_path, src_name, suffix):
assert isinstance(src_name, str)
assert isinstance(suffix, str)
dir_name, file_name = osp.split(src_name)
name, file_suffix = osp.splitext(file_name)
return target_root_path + '/' + name + suffix
file_name = osp.split(src_name)[-1]
name = osp.splitext(file_name)[0]
return osp.join(target_root_path, name + suffix)
def save_2darray(mat, file_name):
@ -37,10 +37,8 @@ def save_2darray(mat, file_name):
mat (ndarray): 2d-array of shape (n, m).
file_name (str): The output file name.
"""
with codecs.open(file_name, 'w', 'utf-8') as fw:
for row in mat:
row_str = ','.join([str(x) for x in row])
fw.write(row_str + '\n')
lines = [','.join([str(x) for x in row]) for row in mat]
list_to_file(file_name, lines)
def save_bboxes_quadrangles(bboxes_with_scores,
@ -144,22 +142,21 @@ def main():
total_img_num = sum([1 for _ in open(args.img_list)])
progressbar = ProgressBar(task_num=total_img_num)
with codecs.open(args.img_list, 'r', 'utf-8') as fr:
for line in fr:
progressbar.update()
img_path = args.img_root + '/' + line.strip()
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = inference_detector(model, img_path)
img_name = osp.basename(img_path)
out_file = osp.join(out_vis_dir, img_name)
kwargs_dict = {
'score_thr': args.score_thr,
'show': False,
'out_file': out_file
}
model.show_result(img_path, result, **kwargs_dict)
for line in list_from_file(args.img_list):
progressbar.update()
img_path = osp.join(args.img_root, line.strip())
if not osp.exists(img_path):
raise FileNotFoundError(img_path)
# Test a single image
result = inference_detector(model, img_path)
img_name = osp.basename(img_path)
out_file = osp.join(out_vis_dir, img_name)
kwargs_dict = {
'score_thr': args.score_thr,
'show': False,
'out_file': out_file
}
model.show_result(img_path, result, **kwargs_dict)
print(f'\nInference done, and results saved in {args.out_dir}\n')