# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
################################################################################
"""
This script extracts the VOC2007 and VOC2012 dataset files [data, labels]
from the given annotations so they can be used for training. The files can
be prepared for the various data splits.
"""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import argparse
import json
import logging
import os
import sys
from glob import glob

import numpy as np

# initiate the logger
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)


def validate_files(input_files):
    """
    The valid files will have name: <class_name>_<split>.txt, e.g.
    'aeroplane_trainval.txt'. We want to remove all the other files
    from the input.
    """
    output_files = []
    for item in input_files:
        if len(item.split('/')[-1].split('_')) == 2:
            output_files.append(item)
    return output_files


def get_data_files(split, args):
    data_dir = os.path.join(args.data_source_dir, 'ImageSets/Main')
    assert os.path.exists(data_dir), "Data: {} doesn't exist".format(data_dir)
    test_data_files = glob(os.path.join(data_dir, '*_test.txt'))
    test_data_files = validate_files(test_data_files)
    if args.separate_partitions > 0:
        train_data_files = glob(os.path.join(data_dir, '*_train.txt'))
        val_data_files = glob(os.path.join(data_dir, '*_val.txt'))
        train_data_files = validate_files(train_data_files)
        val_data_files = validate_files(val_data_files)
        assert len(train_data_files) == len(val_data_files)
        if split == 'train':
            data_files = train_data_files
        elif split == 'test':
            data_files = test_data_files
        else:
            data_files = val_data_files
    else:
        train_data_files = glob(os.path.join(data_dir, '*_trainval.txt'))
        if len(test_data_files) == 0:
            # For the VOC2012 dataset, we have trainval, val and train data.
            train_data_files = glob(os.path.join(data_dir, '*_train.txt'))
            test_data_files = glob(os.path.join(data_dir, '*_val.txt'))
        test_data_files = validate_files(test_data_files)
        train_data_files = validate_files(train_data_files)
        data_files = train_data_files if (split
                                          == 'train') else test_data_files
        assert len(train_data_files) == len(test_data_files), 'Missing classes'
    return data_files


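# For reference, a sketch of how get_data_files() above resolves splits to
# annotation files (derived from the glob patterns; which files exist depends
# on the VOC release on disk):
#   separate_partitions > 0:   train -> *_train.txt, val -> *_val.txt,
#                              test -> *_test.txt
#   separate_partitions == 0:  train -> *_trainval.txt, test -> *_test.txt;
#                              if no *_test.txt exists (VOC2012), it falls
#                              back to *_train.txt / *_val.txt instead

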
def get_images_labels_info(split, args):
    assert os.path.exists(args.data_source_dir), 'Data source NOT found. Abort'

    data_files = get_data_files(split, args)
    # we will construct a map for image name to the vector of -1, 0, 1
    # we sort the data_files which gives sorted class names as well
    img_labels_map = {}
    for cls_num, data_path in enumerate(sorted(data_files)):
        # for this class, we have images and each image will have label
        # 1, -1, 0 -> present, not present, ignore respectively as in VOC data.
        with open(data_path, 'r') as fopen:
            for line in fopen:
                try:
                    img_name, orig_label = line.strip().split()
                    if img_name not in img_labels_map:
                        img_labels_map[img_name] = -np.ones(
                            len(data_files), dtype=np.int32)
                    orig_label = int(orig_label)
                    # in VOC data, -1 (not present), set it to 0 as train target
                    if orig_label == -1:
                        orig_label = 0
                    # in VOC data, 0 (ignore), set it to -1 as train target
                    elif orig_label == 0:
                        orig_label = -1
                    img_labels_map[img_name][cls_num] = orig_label
                except Exception:
                    logger.info('Error processing: {} data_path: {}'.format(
                        line, data_path))

    img_paths, img_labels = [], []
    for item in sorted(img_labels_map.keys()):
        img_paths.append(
            os.path.join(args.data_source_dir, 'JPEGImages', item + '.jpg'))
        img_labels.append(img_labels_map[item])

    output_dict = {}
    if args.generate_json:
        cls_names = []
        for item in sorted(data_files):
            name = item.split('/')[-1].split('.')[0].split('_')[0]
            cls_names.append(name)

        img_ids, json_img_labels = [], []
        for item in sorted(img_labels_map.keys()):
            img_ids.append(item)
            json_img_labels.append(img_labels_map[item])

        for img_idx in range(len(img_ids)):
            img_id = img_ids[img_idx]
            out_lbl = {}
            for cls_idx in range(len(cls_names)):
                name = cls_names[cls_idx]
                out_lbl[name] = int(json_img_labels[img_idx][cls_idx])
            output_dict[img_id] = out_lbl
    return img_paths, img_labels, output_dict


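# A worked example of the remapping in get_images_labels_info(), using an
# illustrative VOC-style annotation line: "000005 -1" for class 'aeroplane'
# (not present) yields train target 0, "000005 0" (difficult/ignore) yields
# -1, and "000005 1" (present) stays 1. Images never listed for a class keep
# the initial value of -1 (ignore).

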
def main():
    parser = argparse.ArgumentParser(description='Create VOC data files')
    parser.add_argument(
        '--data_source_dir',
        type=str,
        default=None,
        help='Path to data directory containing ImageSets and JPEGImages')
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Output directory where images/label information will be written')
    parser.add_argument(
        '--separate_partitions',
        type=int,
        default=0,
        help='Whether to create files separately for partitions train/test/val'
    )
    parser.add_argument(
        '--generate_json',
        type=int,
        default=0,
        help='Whether to generate json files for partitions train/test/val')
    args = parser.parse_args()

    # given the data directory for the partitions train, val, and test, we will
    # write numpy files for each partition.
    partitions = ['train', 'test']
    if args.separate_partitions > 0:
        partitions.append('val')

    for partition in partitions:
        logger.info(
            '========Preparing {} data files========'.format(partition))
        imgs_info, lbls_info, output_dict = get_images_labels_info(
            partition, args)
        img_info_out_path = os.path.join(args.output_dir,
                                         partition + '_images.npy')
        label_info_out_path = os.path.join(args.output_dir,
                                           partition + '_labels.npy')
        logger.info(
            '=================SAVING DATA files=======================')
        logger.info('partition: {} saving img_paths to: {}'.format(
            partition, img_info_out_path))
        logger.info('partition: {} saving lbls_paths: {}'.format(
            partition, label_info_out_path))
        logger.info('partition: {} imgs: {}'.format(partition,
                                                    np.array(imgs_info).shape))
        np.save(img_info_out_path, np.array(imgs_info))
        np.save(label_info_out_path, np.array(lbls_info))
        if args.generate_json:
            json_out_path = os.path.join(args.output_dir,
                                         partition + '_targets.json')
            with open(json_out_path, 'w') as fp:
                json.dump(output_dict, fp)
            logger.info('Saved Json to: {}'.format(json_out_path))
    logger.info('DONE!')


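# For a default run (separate_partitions == 0), main() writes, for each
# partition in ['train', 'test']:
#   <output_dir>/<partition>_images.npy   - array of JPEG image paths
#   <output_dir>/<partition>_labels.npy   - per-image class label vectors
#   <output_dir>/<partition>_targets.json - only with --generate_json 1

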
if __name__ == '__main__':
    main()