51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
|
|
|
from .builder import DATASETS
|
|
|
|
|
|
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but additionally
    exposes the ``CLASSES`` and ``PALETTE`` attributes of the wrapped
    datasets.

    Both attributes are copied from the first dataset only; the wrapper does
    not check that the remaining datasets share the same values.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            Any iterable is accepted; the parent class materializes it into
            the list ``self.datasets`` (PyTorch's ConcatDataset rejects an
            empty iterable).
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        # Read the metadata from ``self.datasets`` (the list built by the
        # parent class) rather than from the argument: the argument may be a
        # one-shot iterable such as a generator, already consumed by now.
        first = self.datasets[0]
        self.CLASSES = first.CLASSES
        self.PALETTE = first.PALETTE
|
|
|
|
|
|
@DATASETS.register_module()
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        # Metadata is forwarded unchanged from the wrapped dataset.
        self.CLASSES = dataset.CLASSES
        self.PALETTE = dataset.PALETTE
        # Cache the original length; every index lookup needs it.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Get item from original dataset.

        Indices follow normal sequence semantics over the repeated length:
        negative indices count from the end, and any index outside
        ``[-len(self), len(self))`` is rejected.

        Args:
            idx (int): Index into the repeated dataset.

        Returns:
            The item ``self.dataset[idx % len(self.dataset)]``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        length = len(self)
        if not -length <= idx < length:
            # Without this check ``idx % self._ori_len`` wraps every index, so
            # iteration through the legacy ``__getitem__`` protocol (e.g.
            # ``for x in ds`` or ``list(ds)``) never terminates, and an empty
            # wrapped dataset raises ZeroDivisionError instead of IndexError.
            raise IndexError(
                f'index {idx} is out of range for RepeatDataset of length '
                f'{length}')
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """The length is multiplied by ``times``"""
        return self.times * self._ori_len
|