# Source: mmpretrain/mmcls/datasets/base_dataset.py
# (file-viewer metadata: 40 lines, 1.0 KiB, Python)

import copy
from abc import ABCMeta, abstractmethod
from torch.utils.data import Dataset
from .pipelines import Compose
class BaseDataset(Dataset, metaclass=ABCMeta):
    """Abstract dataset that feeds annotation records through a pipeline.

    Subclasses must implement :meth:`load_annotations`. Its return value is
    stored as ``self.data_infos`` and must support ``len()`` and integer
    indexing, since ``__len__`` and the ``prepare_*`` methods rely on both.

    Args:
        ann_file: Annotation source, stored unchanged as ``self.ann_file``.
            How it is interpreted is up to the subclass.
        pipeline: Passed to ``Compose`` to build ``self.pipeline``, the
            callable applied to every sample.
        data_prefix: Stored unchanged as ``self.data_prefix``.
        test_mode: When truthy, ``__getitem__`` serves samples through
            :meth:`prepare_test_data`; otherwise through
            :meth:`prepare_train_data`.
    """

    def __init__(self, ann_file, pipeline, data_prefix, test_mode):
        super().__init__()
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        # Called last so every attribute above is already set when the
        # subclass implementation runs.
        self.data_infos = self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        """Load and return the annotation records for this dataset.

        The result becomes ``self.data_infos``.
        """

    def prepare_train_data(self, idx):
        """Return the pipeline output for the training sample at ``idx``."""
        # Deep-copy so the pipeline cannot mutate the cached record.
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Return the pipeline output for the test sample at ``idx``."""
        # Deep-copy so the pipeline cannot mutate the cached record.
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def __len__(self):
        """Return the number of loaded annotation records."""
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx`` for the current mode."""
        # The original had these branches swapped, so a subclass overriding
        # only one of the prepare_* methods was served by the wrong one.
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)