mirror of https://github.com/alibaba/EasyCV.git
49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import bisect
|
|
|
|
from easycv.datasets.builder import build_datasource
|
|
from easycv.datasets.registry import DATASOURCES
|
|
|
|
|
|
@DATASOURCES.register_module
|
|
class SourceConcat(object):
|
|
"""Concat multi data source config.
|
|
"""
|
|
|
|
def __init__(self, data_source_list):
|
|
assert isinstance(data_source_list, (list, tuple)), \
|
|
'data_source_list must be a config list'
|
|
assert len(data_source_list) > 0, \
|
|
'data_source_list should not be an empty list'
|
|
self.source_type = [
|
|
source_cfg['type'] for source_cfg in data_source_list
|
|
]
|
|
self.data_sources = [
|
|
build_datasource(data_source_cfg)
|
|
for data_source_cfg in data_source_list
|
|
]
|
|
self.cumsum_length_list = self.cumsum_length()
|
|
|
|
def get_length(self):
|
|
return self.cumsum_length_list[-1]
|
|
|
|
def __len__(self):
|
|
return self.get_length()
|
|
|
|
def get_sample(self, idx):
|
|
dataset_idx = bisect.bisect_right(self.cumsum_length_list, idx)
|
|
if dataset_idx == 0:
|
|
sample_idx = idx
|
|
else:
|
|
sample_idx = idx - self.cumsum_length_list[dataset_idx - 1]
|
|
|
|
return self.data_sources[dataset_idx].get_sample(sample_idx)
|
|
|
|
def cumsum_length(self):
|
|
len_cumsum_list, idx = [], 0
|
|
for ds in self.data_sources:
|
|
ds_l = ds.get_length()
|
|
len_cumsum_list.append(ds_l + idx)
|
|
idx += ds_l
|
|
return len_cumsum_list
|