40 lines
1.0 KiB
Python
40 lines
1.0 KiB
Python
import copy
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from .pipelines import Compose
|
|
|
|
|
|
class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract base class for datasets whose samples run through a pipeline.

    Subclasses must implement :meth:`load_annotations`; its return value is
    stored as ``self.data_infos`` and must support ``len()`` and indexing.
    Indexing the dataset deep-copies the selected entry and passes the copy
    through ``self.pipeline``.

    Args:
        ann_file: Stored unchanged as ``self.ann_file`` (presumably the
            annotation file that ``load_annotations`` reads -- confirm in
            the concrete subclass).
        pipeline: Passed to ``Compose`` to build ``self.pipeline``.
        data_prefix: Stored unchanged as ``self.data_prefix``.
        test_mode: Truthy selects ``prepare_test_data`` in ``__getitem__``;
            falsy selects ``prepare_train_data``.
    """

    def __init__(self, ann_file, pipeline, data_prefix, test_mode):
        super().__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        # Called last so that subclass implementations can rely on the
        # attributes assigned above.
        self.data_infos = self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        """Return the indexable collection stored as ``self.data_infos``."""

    def prepare_train_data(self, idx):
        """Return the pipeline output for entry ``idx`` in training mode."""
        # Deep copy so the pipeline cannot mutate the cached annotation.
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Return the pipeline output for entry ``idx`` in test mode."""
        # Deep copy so the pipeline cannot mutate the cached annotation.
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def __len__(self):
        """Return the number of entries in ``self.data_infos``."""
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Return the prepared sample at ``idx`` for the current mode."""
        # The branches used to be swapped: test mode dispatched to the
        # training hook and vice versa.
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)
|