mirror of https://github.com/open-mmlab/mmocr.git
117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import math
|
||
|
import os
|
||
|
import os.path as osp
|
||
|
from argparse import ArgumentParser
|
||
|
from functools import partial
|
||
|
|
||
|
import mmcv
|
||
|
from PIL import Image
|
||
|
|
||
|
from mmocr.utils.fileio import list_to_file
|
||
|
|
||
|
|
||
|
def parse_args():
|
||
|
parser = ArgumentParser(description='Generate training and validation set '
|
||
|
'of OpenVINO annotations for Open '
|
||
|
'Images by cropping box image.')
|
||
|
parser.add_argument(
|
||
|
'root_path', help='Root dir containing images and annotations')
|
||
|
parser.add_argument(
|
||
|
'n_proc', default=1, type=int, help='Number of processes to run')
|
||
|
args = parser.parse_args()
|
||
|
return args
|
||
|
|
||
|
|
||
|
def process_img(args, src_image_root, dst_image_root):
|
||
|
# Dirty hack for multi-processing
|
||
|
img_idx, img_info, anns = args
|
||
|
src_img = Image.open(osp.join(src_image_root, img_info['file_name']))
|
||
|
labels = []
|
||
|
for ann_idx, ann in enumerate(anns):
|
||
|
attrs = ann['attributes']
|
||
|
text_label = attrs['transcription']
|
||
|
|
||
|
# Ignore illegible or non-English words
|
||
|
if not attrs['legible'] or attrs['language'] != 'english':
|
||
|
continue
|
||
|
|
||
|
x, y, w, h = ann['bbox']
|
||
|
x, y = max(0, math.floor(x)), max(0, math.floor(y))
|
||
|
w, h = math.ceil(w), math.ceil(h)
|
||
|
dst_img = src_img.crop((x, y, x + w, y + h))
|
||
|
dst_img_name = f'img_{img_idx}_{ann_idx}.jpg'
|
||
|
dst_img_path = osp.join(dst_image_root, dst_img_name)
|
||
|
# Preserve JPEG quality
|
||
|
dst_img.save(dst_img_path, qtables=src_img.quantization)
|
||
|
labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}'
|
||
|
f' {text_label}')
|
||
|
src_img.close()
|
||
|
return labels
|
||
|
|
||
|
|
||
|
def convert_openimages(root_path,
|
||
|
dst_image_path,
|
||
|
dst_label_filename,
|
||
|
annotation_filename,
|
||
|
img_start_idx=0,
|
||
|
nproc=1):
|
||
|
annotation_path = osp.join(root_path, annotation_filename)
|
||
|
if not osp.exists(annotation_path):
|
||
|
raise Exception(
|
||
|
f'{annotation_path} not exists, please check and try again.')
|
||
|
src_image_root = root_path
|
||
|
|
||
|
# outputs
|
||
|
dst_label_file = osp.join(root_path, dst_label_filename)
|
||
|
dst_image_root = osp.join(root_path, dst_image_path)
|
||
|
os.makedirs(dst_image_root, exist_ok=True)
|
||
|
|
||
|
annotation = mmcv.load(annotation_path)
|
||
|
|
||
|
process_img_with_path = partial(
|
||
|
process_img,
|
||
|
src_image_root=src_image_root,
|
||
|
dst_image_root=dst_image_root)
|
||
|
tasks = []
|
||
|
anns = {}
|
||
|
for ann in annotation['annotations']:
|
||
|
anns.setdefault(ann['image_id'], []).append(ann)
|
||
|
for img_idx, img_info in enumerate(annotation['images']):
|
||
|
tasks.append((img_idx + img_start_idx, img_info, anns[img_info['id']]))
|
||
|
labels_list = mmcv.track_parallel_progress(
|
||
|
process_img_with_path, tasks, keep_order=True, nproc=nproc)
|
||
|
final_labels = []
|
||
|
for label_list in labels_list:
|
||
|
final_labels += label_list
|
||
|
list_to_file(dst_label_file, final_labels)
|
||
|
return len(annotation['images'])
|
||
|
|
||
|
|
||
|
def main():
|
||
|
args = parse_args()
|
||
|
root_path = args.root_path
|
||
|
print('Processing training set...')
|
||
|
num_train_imgs = 0
|
||
|
for s in '125f':
|
||
|
num_train_imgs = convert_openimages(
|
||
|
root_path=root_path,
|
||
|
dst_image_path=f'image_{s}',
|
||
|
dst_label_filename=f'train_{s}_label.txt',
|
||
|
annotation_filename=f'text_spotting_openimages_v5_train_{s}.json',
|
||
|
img_start_idx=num_train_imgs,
|
||
|
nproc=args.n_proc)
|
||
|
print('Processing validation set...')
|
||
|
convert_openimages(
|
||
|
root_path=root_path,
|
||
|
dst_image_path='image_val',
|
||
|
dst_label_filename='val_label.txt',
|
||
|
annotation_filename='text_spotting_openimages_v5_validation.json',
|
||
|
img_start_idx=num_train_imgs,
|
||
|
nproc=args.n_proc)
|
||
|
print('Finish')
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
main()
|