# 40 lines · 1.0 KiB · Python
|
import copy
|
||
|
from abc import ABCMeta, abstractmethod
|
||
|
|
||
|
from torch.utils.data import Dataset
|
||
|
|
||
|
from .pipelines import Compose
|
||
|
|
||
|
|
||
|
class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract base class for datasets that feed records through a pipeline.

    Subclasses implement :meth:`load_annotations`, whose return value is
    stored as ``self.data_infos``. That value must support ``len()`` and
    integer indexing. Each time a sample is fetched, its record is
    deep-copied and passed through ``self.pipeline``.

    Args:
        ann_file: Annotation source. It is stored as ``self.ann_file`` for
            :meth:`load_annotations` to use; this class does not read it.
        pipeline: Passed to ``Compose``, and the result is used as the
            per-sample transform. Presumably a sequence of transform
            configs; confirm against ``Compose``.
        data_prefix: Stored as ``self.data_prefix``; this class does not
            read it.
        test_mode: If true, ``__getitem__`` uses :meth:`prepare_test_data`.
            Otherwise it uses :meth:`prepare_train_data`.
    """

    def __init__(self, ann_file, pipeline, data_prefix, test_mode):
        super().__init__()
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        # Annotations are loaded after the attributes above are set, so that
        # load_annotations() implementations can read them.
        self.data_infos = self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        """Return the per-sample records for this dataset.

        The result is stored as ``self.data_infos``. It must support
        ``len()`` and integer indexing.
        """

    def prepare_train_data(self, idx):
        """Return the record at ``idx`` after running it through the pipeline.

        The record is deep-copied first, so pipeline transforms cannot
        mutate the cached annotations.
        """
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Return the record at ``idx`` after running it through the pipeline.

        Same as :meth:`prepare_train_data` here. It is kept as a separate
        hook so subclasses can treat evaluation samples differently.
        """
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        In test mode this dispatches to :meth:`prepare_test_data`.
        Otherwise it dispatches to :meth:`prepare_train_data`.
        """
        # Bug fix: the original branch was inverted and called the train
        # path in test mode.
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)
|