# Copyright (c) OpenMMLab. All rights reserved.
"""Extracting subsets from the COCO 2017 dataset.

This script is mainly used to quickly debug and verify the correctness of
a program.

The root folder must be organized as follows:

├── root
│   ├── annotations
│   ├── train2017
│   ├── val2017
│   ├── test2017

Currently, only COCO 2017 is supported. User-defined datasets in the
standard COCO JSON format will be supported in the future.

Example:
    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img ${NUM_IMG}
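
    Further example invocations (a sketch: the flags are the ones defined
    in parse_args below, with illustrative values):

    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img 20 --seed 0
    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --classes cat dog person
    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --area-size small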
"""

import argparse
import os.path as osp
import shutil

import mmengine
import numpy as np
from pycocotools.coco import COCO

# TODO: Currently only supports coco2017
def _process_data(args,
                  in_dataset_type: str,
                  out_dataset_type: str,
                  year: str = '2017'):
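    """Extract a subset of ``in_dataset_type`` and save it as
    ``out_dataset_type``.

    Annotations are optionally filtered by class names and area range,
    the matching images are shuffled, and the first ``args.num_img`` of
    them are copied to ``args.out_dir`` together with a new COCO-style
    annotation file.
    """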
    assert in_dataset_type in ('train', 'val')
    assert out_dataset_type in ('train', 'val')

    in_ann_file_name = f'annotations/instances_{in_dataset_type}{year}.json'
    out_ann_file_name = f'annotations/instances_{out_dataset_type}{year}.json'

    ann_path = osp.join(args.root, in_ann_file_name)
    json_data = mmengine.load(ann_path)

    new_json_data = {
        'info': json_data['info'],
        'licenses': json_data['licenses'],
        'categories': json_data['categories'],
        'images': [],
        'annotations': []
    }

    area_dict = {
        'small': [0., 32 * 32],
        'medium': [32 * 32, 96 * 96],
        'large': [96 * 96, float('inf')]
    }
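    # These ranges follow the standard COCO convention: objects with an
    # area below 32*32 pixels count as "small", between 32*32 and 96*96
    # as "medium", and above 96*96 as "large".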

    coco = COCO(ann_path)

    # filter annotations by category ids and area range
    areaRng = area_dict[args.area_size] if args.area_size else []
    catIds = coco.getCatIds(catNms=args.classes) if args.classes else []
    ann_ids = coco.getAnnIds(catIds=catIds, areaRng=areaRng)
    ann_info = coco.loadAnns(ann_ids)

    # get image ids by anns set
    filter_img_ids = {ann['image_id'] for ann in ann_info}
    filter_img = coco.loadImgs(list(filter_img_ids))
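    # Note: an image usually has several annotations, so the ids are
    # collected into a set above to deduplicate them before the image
    # info is loaded.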

    # shuffle
    np.random.shuffle(filter_img)

    num_img = args.num_img if args.num_img > 0 else len(filter_img)
    if num_img > len(filter_img):
        print(f'num_img is too large and will be set to {len(filter_img)} '
              'because there are not enough images left after filtering '
              'by classes and area_size.')
        num_img = len(filter_img)

    progress_bar = mmengine.ProgressBar(num_img)

    for i in range(num_img):
        file_name = filter_img[i]['file_name']
        image_path = osp.join(args.root, in_dataset_type + year, file_name)

        ann_ids = coco.getAnnIds(
            imgIds=[filter_img[i]['id']], catIds=catIds, areaRng=areaRng)
        img_ann_info = coco.loadAnns(ann_ids)

        new_json_data['images'].append(filter_img[i])
        new_json_data['annotations'].extend(img_ann_info)

        shutil.copy(image_path, osp.join(args.out_dir,
                                         out_dataset_type + year))

        progress_bar.update()

    mmengine.dump(new_json_data, osp.join(args.out_dir, out_ann_file_name))
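    # The dumped file keeps the original COCO image and annotation ids
    # and layout, so the subset can be loaded again with pycocotools'
    # COCO() or any standard COCO dataloader.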


def _make_dirs(out_dir):
    mmengine.mkdir_or_exist(out_dir)
    mmengine.mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mmengine.mkdir_or_exist(osp.join(out_dir, 'train2017'))
    mmengine.mkdir_or_exist(osp.join(out_dir, 'val2017'))
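    # No test2017 folder is created: subsets are only extracted from the
    # train and val splits.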


def parse_args():
    parser = argparse.ArgumentParser(description='Extract COCO subset')
    parser.add_argument('root', help='root path')
    parser.add_argument(
        'out_dir',
        type=str,
        help='directory where the coco subset will be saved.')
    parser.add_argument(
        '--num-img',
        default=50,
        type=int,
        help='number of images to extract; -1 means all images')
    parser.add_argument(
        '--area-size',
        choices=['small', 'medium', 'large'],
        help='filter ground-truth info by area size')
    parser.add_argument(
        '--classes', nargs='+', help='filter ground truth by class name')
    parser.add_argument(
        '--use-training-set',
        action='store_true',
        help='Whether to extract the training subset from the training '
        'set. By default, the training subset is extracted from the '
        'validation set, which is faster.')
    parser.add_argument('--seed', default=-1, type=int, help='seed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out_dir != args.root, \
        'Files would be overwritten in place, ' \
        'so out_dir must differ from root!'

    seed = args.seed
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)

    _make_dirs(args.out_dir)

    print('====Start processing train dataset====')
    if args.use_training_set:
        _process_data(args, 'train', 'train')
    else:
        _process_data(args, 'val', 'train')
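        # The default branch builds the training subset from val2017,
        # which is much smaller than train2017 and therefore faster to
        # filter and copy (see the --use-training-set flag).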

    print('\n====Start processing val dataset====')
    _process_data(args, 'val', 'val')
    print(f'\nResults saved to {args.out_dir}')


if __name__ == '__main__':
    main()