mirror of https://github.com/alibaba/EasyCV.git
125 lines
4.1 KiB
Python
125 lines
4.1 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import logging
|
|
import os
|
|
|
|
from easycv.file import io
|
|
from easycv.utils import dist_utils
|
|
|
|
|
|
def get_path_and_index(file_list_or_path):
    """Split a tfrecord file list into data paths and derived index paths.

    Args:
        file_list_or_path: Either a str path to a text file with one path per
            line (opened via ``easycv.file.io``), or an iterable of path
            strings.

    Returns:
        tuple(list[str], list[str]): ``(path, index)``. ``path`` holds every
        non-blank, whitespace-stripped entry that does not end in ``.idx``.
        ``index[i] == path[i] + '.idx'``. Entries that already end in
        ``.idx`` are ignored, because each index path is derived from its
        data path.
    """
    if isinstance(file_list_or_path, str):
        # Context manager so the (possibly remote) file handle gets closed.
        with io.open(file_list_or_path) as f:
            lines = f.readlines()
    else:
        lines = file_list_or_path

    path = []
    for line in lines:
        line = line.strip()
        # Skip blank lines (e.g. a trailing newline in the list file) and
        # explicit .idx entries.
        if line and not line.endswith('.idx'):
            path.append(line)
    index = [p + '.idx' for p in path]
    return path, index
|
|
|
|
|
|
def _to_target_path(target_path, src_path):
    """Return the path under ``target_path`` that keeps ``src_path``'s basename."""
    return os.path.join(target_path, os.path.split(src_path)[-1])


def _download_slice(src_paths, target_path, slice_count, slice_id, force):
    """Copy the entries of ``src_paths`` assigned to worker ``slice_id``.

    Assignment is round-robin by list position: entry ``i`` belongs to the
    worker for which ``i % slice_count == slice_id``.
    """
    for i, src_path in enumerate(src_paths):
        # The position comes from enumerate, so it advances for every entry,
        # including ones skipped because they already exist. Advancing it
        # only after a copy makes workers disagree about ownership: some
        # files get fetched twice and others not at all.
        if i % slice_count != slice_id:
            continue
        target_file = _to_target_path(target_path, src_path)
        if not force and io.exists(target_file):
            logging.info('%s already exists, skip download!', target_file)
            continue
        io.copy(src_path, target_file)
        logging.info('Finished download file: %s', src_path)


# multi worker download using oss
def download_tfrecord(file_list_or_path,
                      target_path,
                      slice_count=1,
                      slice_id=0,
                      force=False):
    """Download tfrecord data (e.g. from OSS) into a local directory.

    The download is sharded across workers (one per GPU process).
    ``slice_count`` is the number of workers and ``slice_id`` is the ID of
    this worker. Each worker copies only the files at list positions ``i``
    with ``i % slice_count == slice_id``.

    Supports ImageNet-style tfrecord layouts::

        tfrecord_dir
        |---train1
        |---train1.idx
        |---train2
        |---train2.idx
        |---...

    Args:
        file_list_or_path: A list/tuple of absolute data paths, or a path to a
            text file (opened via ``easycv.file.io``) with one path per line.
            Entries are whitespace-stripped and blank entries are ignored. If
            any entry ends in ``.idx``, exactly those entries are used as the
            index files. Otherwise ``<data path>.idx`` is assumed for every
            data path.
        target_path: Download directory. It is created if missing.
        slice_count: Number of download workers.
        slice_id: ID of this download worker, in ``[0, slice_count)``.
        force: If False, skip files that already exist in ``target_path``.
            If True, copy them again and replace the existing files.

    Returns:
        tuple(list[str], list[str]): ``(path, index_path)``. These are the
        target paths of *all* data files and *all* index files, not only
        this worker's share. ``dist_utils.barrier()`` is called before
        returning, presumably so every worker's share is on disk by then.
    """
    # NOTE(review): dist_zero_exec presumably lets rank 0 run this first.
    with dist_utils.dist_zero_exec():
        os.makedirs(target_path, exist_ok=True)

    logging.info(f'num gpu(slice_count): {slice_count}')

    if isinstance(file_list_or_path, (list, tuple)):
        lines = file_list_or_path
    else:
        with io.open(file_list_or_path, 'r') as f:
            lines = f.readlines()
    # Strip newlines/whitespace and drop blank entries (e.g. a trailing empty
    # line in the list file), which would otherwise be "copied" as ''.
    all_file_list = [line.strip() for line in lines if line.strip()]

    all_data_list = [p for p in all_file_list if not p.endswith('.idx')]
    # NOTE: when .idx files are listed explicitly, they are used as given.
    # Their order is not re-aligned with all_data_list.
    all_index_list = [p for p in all_file_list if p.endswith('.idx')]
    if not all_index_list:
        all_index_list = [p + '.idx' for p in all_data_list]

    _download_slice(all_data_list, target_path, slice_count, slice_id, force)
    _download_slice(all_index_list, target_path, slice_count, slice_id, force)

    logging.info('rank %s finish downloads!' % slice_id)

    dist_utils.barrier()

    # Return the target paths of every file, regardless of which worker
    # downloaded it.
    all_data_list = [_to_target_path(target_path, p) for p in all_data_list]
    all_index_list = [_to_target_path(target_path, p) for p in all_index_list]
    return all_data_list, all_index_list
|