mirror of https://github.com/alibaba/EasyCV.git
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
import os
|
|
|
|
from mmcv.runner import get_dist_info
|
|
from tqdm import tqdm
|
|
|
|
from easycv.file import io
|
|
|
|
|
|
def split_listfile_byrank(list_file,
                          label_balance,
                          save_path='data/',
                          delimeter=' '):
    """Split a list file into one shard file per distributed rank.

    Every non-blank line of ``list_file`` is one record; its first
    ``delimeter``-separated field is taken as the path and its last field
    as the label. When only a single process is running, ``list_file`` is
    returned untouched. Otherwise ``rank_<i>.txt`` is written under
    ``save_path`` for every rank ``i`` and the path of the shard that
    belongs to the calling rank is returned.

    Each shard holds ``n // world_size + 1`` units (records, or labels in
    label-balanced mode) and wraps around to the start of the list when it
    runs past the end, so shards can share some entries.

    Args:
        list_file: Path of the list file to split.
        label_balance: If true, shard by label: records are grouped by
            label and each rank receives all records of its labels. If
            false, shard the records themselves in file order.
        save_path: Directory the ``rank_<i>.txt`` files are written to. It
            is created when missing.
        delimeter: Field separator used in ``list_file``. It is also used
            when writing label-balanced shards so that the shards keep the
            same format as the input.

    Returns:
        ``list_file`` when ``world_size`` is 1, otherwise the path of the
        shard file for the current rank.

    Raises:
        ValueError: If ``list_file`` contains no non-blank lines.
    """
    rank, world_size = get_dist_info()
    if world_size == 1:
        return list_file

    # Close the input handle deterministically instead of leaking it.
    with io.open(list_file) as f:
        raw_lines = f.readlines()

    # Drop blank lines (they would become records with an empty label) and
    # make sure every record is newline-terminated: shards wrap around the
    # list, so an unterminated last line would be glued to the next record.
    lines = [
        line if line.endswith('\n') else line + '\n' for line in raw_lines
        if line.strip()
    ]
    if not lines:
        raise ValueError('list file %s contains no records' % list_file)

    # exist_ok avoids the check-then-create race between ranks that call
    # this function at the same time.
    os.makedirs(save_path, exist_ok=True)

    # NOTE(review): every rank writes every shard file, so concurrent ranks
    # rewrite each other's (identical) files - confirm callers synchronise
    # before reading the returned shard.
    if label_balance:
        # Group paths by label, keeping first-seen label order.
        label2files = {}
        for line in tqdm(lines):
            fields = line.rstrip('\r\n').split(delimeter)
            label2files.setdefault(fields[-1].strip(), []).append(fields[0])

        label_list = list(label2files)
        length = len(label_list) // world_size + 1

        for idx in tqdm(range(world_size)):
            shard_file = os.path.join(save_path, 'rank_%d.txt' % idx)
            with io.open(shard_file, 'w') as f:
                for j in range(idx * length, (idx + 1) * length):
                    # The modulo wraps the last shards back to the first
                    # labels.
                    j_label = label_list[j % len(label_list)]
                    for p in label2files[j_label]:
                        # Write with the input delimiter so the shard is
                        # parsed exactly like the original list file.
                        f.write('%s%s%s\n' % (p, delimeter, j_label))
    else:
        length = len(lines) // world_size + 1
        for idx in range(world_size):
            shard_file = os.path.join(save_path, 'rank_%d.txt' % idx)
            with io.open(shard_file, 'w') as f:
                for j in range(idx * length, (idx + 1) * length):
                    # The modulo wraps the last shards back to the first
                    # records.
                    f.write(lines[j % len(lines)])

    return os.path.join(save_path, 'rank_%d.txt' % rank)
|