mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] add download and convert script of dataset (#11)
parent 8ce6886994
commit 71dfeb335f

@@ -0,0 +1,58 @@
import os.path as osp

import mmcv
import mmengine


def convert_balloon_to_coco(ann_file, out_file, image_prefix):
    data_infos = mmengine.load(ann_file)

    annotations = []
    images = []
    obj_count = 0
    for idx, v in enumerate(mmengine.track_iter_progress(data_infos.values())):
        filename = v['filename']
        img_path = osp.join(image_prefix, filename)
        # Read the image to obtain its height and width for the COCO record.
        height, width = mmcv.imread(img_path).shape[:2]

        images.append(
            dict(id=idx, file_name=filename, height=height, width=width))

        for _, obj in v['regions'].items():
            assert not obj['region_attributes']
            obj = obj['shape_attributes']
            px = obj['all_points_x']
            py = obj['all_points_y']
            # Shift polygon points to pixel centers, then flatten to
            # [x0, y0, x1, y1, ...] as the COCO segmentation field expects.
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            x_min, y_min, x_max, y_max = (min(px), min(py), max(px), max(py))

            data_anno = dict(
                image_id=idx,
                id=obj_count,
                category_id=0,
                bbox=[x_min, y_min, x_max - x_min, y_max - y_min],
                area=(x_max - x_min) * (y_max - y_min),
                segmentation=[poly],
                iscrowd=0)
            annotations.append(data_anno)
            obj_count += 1

    coco_format_json = dict(
        images=images,
        annotations=annotations,
        categories=[{
            'id': 0,
            'name': 'balloon'
        }])
    mmengine.dump(coco_format_json, out_file)


if __name__ == '__main__':
    convert_balloon_to_coco('data/balloon/train/via_region_data.json',
                            'data/balloon/train.json', 'data/balloon/train/')
    convert_balloon_to_coco('data/balloon/val/via_region_data.json',
                            'data/balloon/val.json', 'data/balloon/val/')
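As a quick sanity check of the converter's output (a minimal sketch; it assumes the script above has already been run and relies only on mmengine, which the converter itself imports):

import mmengine

# Load the generated COCO-style annotation file and inspect it.
coco = mmengine.load('data/balloon/train.json')
assert set(coco) == {'images', 'annotations', 'categories'}
print(len(coco['images']), 'images,', len(coco['annotations']), 'annotations')
print(coco['categories'])  # expected: [{'id': 0, 'name': 'balloon'}]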
@@ -0,0 +1,105 @@
import argparse
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import TarFile
from zipfile import ZipFile

import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description='Download datasets for training')
    parser.add_argument(
        '--dataset-name', type=str, help='dataset name', default='coco2017')
    parser.add_argument(
        '--save-dir',
        type=str,
        help='the dir to save dataset',
        default='data/coco')
    parser.add_argument(
        '--unzip',
        action='store_true',
        help='whether to unzip the dataset; the zipped files will be kept')
    parser.add_argument(
        '--delete',
        action='store_true',
        help='delete the downloaded zipped files')
    parser.add_argument(
        '--threads', type=int, help='number of threads', default=4)
    args = parser.parse_args()
    return args


def download(url, dir, unzip=True, delete=False, threads=1):

    def download_one(url, dir):
        f = dir / Path(url).name
        if Path(url).is_file():
            # A local file was passed instead of a URL: move it into place.
            Path(url).rename(f)
        elif not f.exists():
            print(f'Downloading {url} to {f}')
            torch.hub.download_url_to_file(url, f, progress=True)
        if unzip and f.suffix in ('.zip', '.tar'):
            print(f'Unzipping {f.name}')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)
            elif f.suffix == '.tar':
                TarFile(f).extractall(path=dir)
        if delete:
            f.unlink()
            print(f'Deleted {f}')

    dir = Path(dir)
    if threads > 1:
        # Download several archives concurrently.
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
def main():
    args = parse_args()
    path = Path(args.save_dir)
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)
    data2url = dict(
        # TODO: Support for downloading Panoptic Segmentation of COCO
        coco2017=[
            'http://images.cocodataset.org/zips/train2017.zip',
            'http://images.cocodataset.org/zips/val2017.zip',
            'http://images.cocodataset.org/zips/test2017.zip',
            'http://images.cocodataset.org/annotations/' +
            'annotations_trainval2017.zip'
        ],
        lvis=[
            'https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip',  # noqa
            'https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip',  # noqa
        ],
        voc2007=[
            'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',  # noqa
            'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',  # noqa
            'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar',  # noqa
        ],
        balloon=[
            'https://github.com/matterport/Mask_RCNN/' +
            'releases/download/v2.1/balloon_dataset.zip'
        ])
    url = data2url.get(args.dataset_name, None)
    if url is None:
        print('Only COCO, VOC, balloon and LVIS are supported now!')
        return
    download(
        url,
        dir=path,
        unzip=args.unzip,
        delete=args.delete,
        threads=args.threads)


if __name__ == '__main__':
    main()
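For reference, a minimal usage sketch of the download helper, calling download() directly rather than through the CLI; the balloon URL is the one registered in data2url above:

from pathlib import Path

# Fetch and unpack the balloon dataset into data/balloon/.
download(
    'https://github.com/matterport/Mask_RCNN/' +
    'releases/download/v2.1/balloon_dataset.zip',
    dir=Path('data/balloon'),
    unzip=True,
    delete=False,
    threads=1)

Through the CLI, this corresponds to passing --dataset-name balloon --save-dir data/balloon --unzip to the script.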