2022-11-26 09:27:30 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import argparse
|
|
|
|
import json
|
|
|
|
import random
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
from pycocotools.coco import COCO
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the COCO dataset split script.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``json``
        (str), ``out_dir`` (str), ``ratios`` (list of float), ``shuffle``
        (bool) and ``seed`` (int, ``-1`` by default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--json', type=str, required=True, help='COCO json label path')
    parser.add_argument(
        '--out-dir', type=str, required=True, help='output path')
    # ``--ratios`` is mandatory: the split cannot be computed without it, so
    # fail at parse time with a usage message instead of crashing later.
    parser.add_argument(
        '--ratios',
        nargs='+',
        type=float,
        required=True,
        help='ratio for sub dataset, if set 2 number then will generate '
        'trainval + test (eg. "0.85 0.15" or "2 1"), if set 3 number '
        'then will generate train + val + test (eg. "0.8 0.1 0.1" or '
        '"2 1 1")')
    parser.add_argument(
        '--shuffle',
        action='store_true',
        help='Whether to shuffle the images before splitting')
    parser.add_argument('--seed', default=-1, type=int, help='seed')
    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def split_coco_dataset(coco_json_path: str, save_dir: str, ratios: list,
                       shuffle: bool, seed: int):
    """Split one COCO annotation file into several smaller json files.

    With two ratios the images are divided into ``trainval.json`` and
    ``test.json``; with three ratios into ``train.json``, ``val.json`` and
    ``test.json``. Every output file keeps the full category list and only
    the images/annotations that belong to its split.

    Args:
        coco_json_path (str): Path of the COCO json label file to split.
        save_dir (str): Directory the new json files are written to. It is
            created (including parents) when missing.
        ratios (list): Two or three non-negative numbers. They are
            normalized by their sum, so both ``[0.8, 0.2]`` and ``[4, 1]``
            are accepted.
        shuffle (bool): Whether to shuffle the image ids before splitting.
        seed (int): Seed for the random generators; ``-1`` leaves them
            unseeded.

    Raises:
        FileNotFoundError: If ``coco_json_path`` does not exist.
        ValueError: If ``ratios`` does not hold 2 or 3 values, or the values
            are negative, non-finite or sum to zero.
    """
    if not Path(coco_json_path).exists():
        raise FileNotFoundError(f'Can not find {coco_json_path}')

    # Validate the ratios before creating directories or loading the dataset
    # so that bad input has no side effects.
    if ratios is None or len(ratios) not in (2, 3):
        raise ValueError('ratios must set 2 or 3 group!')
    ratios = np.asarray(ratios, dtype=float)
    if (not np.isfinite(ratios).all() or (ratios < 0).any()
            or ratios.sum() <= 0):
        raise ValueError(
            'ratios must be non-negative numbers with a positive sum!')

    # ratio normalize
    ratios = ratios / ratios.sum()

    if len(ratios) == 2:
        ratio_train, ratio_test = ratios
        ratio_val = 0
        train_type = 'trainval'
    else:
        ratio_train, ratio_val, ratio_test = ratios
        train_type = 'train'

    # ``exist_ok`` avoids the check-then-create race of a separate exists().
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Read coco info
    coco = COCO(coco_json_path)
    coco_image_ids = coco.getImgIds()

    # gen image number of each dataset; the train split takes whatever is
    # left after the truncated val/test counts, so no image is dropped.
    total_image_num = len(coco_image_ids)
    val_image_num = int(total_image_num * ratio_val)
    test_image_num = int(total_image_num * ratio_test)
    train_image_num = total_image_num - val_image_num - test_image_num
    print('Split info: ====== \n'
          f'Train ratio = {ratio_train}, number = {train_image_num}\n'
          f'Val ratio = {ratio_val}, number = {val_image_num}\n'
          f'Test ratio = {ratio_test}, number = {test_image_num}')

    seed = int(seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        # The shuffle below uses the stdlib ``random`` module, so that is the
        # generator which must be seeded for a reproducible split. numpy is
        # still seeded to keep the previous global side effect.
        random.seed(seed)
        np.random.seed(seed)

    if shuffle:
        print('shuffle dataset.')
        random.shuffle(coco_image_ids)

    # split each dataset
    val_end = train_image_num + val_image_num
    train_image_ids = coco_image_ids[:train_image_num]
    val_image_ids = (
        coco_image_ids[train_image_num:val_end] if val_image_num != 0 else None)
    test_image_ids = coco_image_ids[val_end:]

    # Save new json. Each split carries its own file name so the target file
    # never depends on comparing list contents (two equal lists, e.g. two
    # empty splits, would otherwise be written to the same file).
    categories = coco.loadCats(coco.getCatIds())
    splits = (
        (train_type, train_image_ids),
        ('val', val_image_ids),
        ('test', test_image_ids),
    )
    for split_name, img_id_list in splits:
        if img_id_list is None:
            continue

        # NOTE: pycocotools' ``getAnnIds`` treats an empty ``imgIds`` as
        # "no filter" and returns every annotation, so an empty split must
        # skip the lookup to stay empty.
        ann_ids = coco.getAnnIds(imgIds=img_id_list) if img_id_list else []

        # Gen new json
        img_dict = {
            'images': coco.loadImgs(ids=img_id_list),
            'categories': categories,
            'annotations': coco.loadAnns(ann_ids)
        }

        # save json; utf-8 is required because ``ensure_ascii=False`` may
        # emit non-ASCII characters regardless of the platform default.
        json_file_path = Path(save_dir, f'{split_name}.json')
        print(f'Saving json to {json_file_path}')
        with open(json_file_path, 'w', encoding='utf-8') as f_json:
            json.dump(img_dict, f_json, ensure_ascii=False, indent=2)

    print('All done!')
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Script entry point: read the CLI options and run the dataset split."""
    options = parse_args()
    split_coco_dataset(
        coco_json_path=options.json,
        save_dir=options.out_dir,
        ratios=options.ratios,
        shuffle=options.shuffle,
        seed=options.seed)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the split only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|