Support lazy loading in dataset

Add an option for lazy loading, which defers parsing the data until it is actually used. Specifically, the query data is parsed only when `dataset.query` is first accessed, and `train` and `gallery` are not parsed in advance. This is useful when you want to run on the test dataset but the train dataset is large.

See examples in market1501.py
This commit is contained in:
liaoxingyu 2021-06-17 16:59:02 +08:00
parent d792a69f3f
commit 10b04b75ff
2 changed files with 31 additions and 21 deletions

View File

@ -19,9 +19,9 @@ class Dataset(object):
This is the base class for ``ImageDataset`` and ``VideoDataset``. This is the base class for ``ImageDataset`` and ``VideoDataset``.
Args: Args:
train (list): contains tuples of (img_path(s), pid, camid). train (list or Callable): contains tuples of (img_path(s), pid, camid).
query (list): contains tuples of (img_path(s), pid, camid). query (list or Callable): contains tuples of (img_path(s), pid, camid).
gallery (list): contains tuples of (img_path(s), pid, camid). gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
transform: transform function. transform: transform function.
mode (str): 'train', 'query' or 'gallery'. mode (str): 'train', 'query' or 'gallery'.
combineall (bool): combines train, query and gallery in a combineall (bool): combines train, query and gallery in a
@ -32,17 +32,14 @@ class Dataset(object):
def __init__(self, train, query, gallery, transform=None, mode='train', def __init__(self, train, query, gallery, transform=None, mode='train',
combineall=False, verbose=True, **kwargs): combineall=False, verbose=True, **kwargs):
self.train = train self._train = train
self.query = query self._query = query
self.gallery = gallery self._gallery = gallery
self.transform = transform self.transform = transform
self.mode = mode self.mode = mode
self.combineall = combineall self.combineall = combineall
self.verbose = verbose self.verbose = verbose
self.num_train_pids = self.get_num_pids(self.train)
self.num_train_cams = self.get_num_cams(self.train)
if self.combineall: if self.combineall:
self.combine_all() self.combine_all()
@ -56,6 +53,24 @@ class Dataset(object):
raise ValueError('Invalid mode. Got {}, but expected to be ' raise ValueError('Invalid mode. Got {}, but expected to be '
'one of [train | query | gallery]'.format(self.mode)) 'one of [train | query | gallery]'.format(self.mode))
# Lazily-resolved training split. ``_train`` holds either the data itself
# or a zero-argument callable that produces it.
@property
def train(self):
if callable(self._train):
# Replace the stored callable with its return value so that later
# accesses reuse the result (as long as that result is not callable).
self._train = self._train()
return self._train
# Lazily-resolved query split. ``_query`` holds either the data itself
# or a zero-argument callable that produces it.
@property
def query(self):
if callable(self._query):
# Replace the stored callable with its return value so that later
# accesses reuse the result (as long as that result is not callable).
self._query = self._query()
return self._query
# Lazily-resolved gallery split. ``_gallery`` holds either the data itself
# or a zero-argument callable that produces it.
@property
def gallery(self):
if callable(self._gallery):
# Replace the stored callable with its return value so that later
# accesses reuse the result (as long as that result is not callable).
self._gallery = self._gallery()
return self._gallery
def __getitem__(self, index): def __getitem__(self, index):
raise NotImplementedError raise NotImplementedError
@ -102,15 +117,14 @@ class Dataset(object):
for img_path, pid, camid in data: for img_path, pid, camid in data:
if pid in self._junk_pids: if pid in self._junk_pids:
continue continue
pid = self.dataset_name + "_test_" + str(pid) pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid)
camid = self.dataset_name + "_test_" + str(camid) camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid)
combined.append((img_path, pid, camid)) combined.append((img_path, pid, camid))
_combine_data(self.query) _combine_data(self.query)
_combine_data(self.gallery) _combine_data(self.gallery)
self.train = combined self._train = combined
self.num_train_pids = self.get_num_pids(self.train)
def check_before_run(self, required_files): def check_before_run(self, required_files):
"""Checks if required files exist before going deeper. """Checks if required files exist before going deeper.
@ -134,9 +148,6 @@ class ImageDataset(Dataset):
data in each batch has shape (batch_size, channel, height, width). data in each batch has shape (batch_size, channel, height, width).
""" """
def __init__(self, train, query, gallery, **kwargs):
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
def show_train(self): def show_train(self):
num_train_pids, num_train_cams = self.parse_data(self.train) num_train_pids, num_train_cams = self.parse_data(self.train)

View File

@ -62,11 +62,10 @@ class Market1501(ImageDataset):
required_files.append(self.extra_gallery_dir) required_files.append(self.extra_gallery_dir)
self.check_before_run(required_files) self.check_before_run(required_files)
train = self.process_dir(self.train_dir) train = lambda: self.process_dir(self.train_dir)
query = self.process_dir(self.query_dir, is_train=False) query = lambda: self.process_dir(self.query_dir, is_train=False)
gallery = self.process_dir(self.gallery_dir, is_train=False) gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \
if self.market1501_500k: (self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else [])
gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
super(Market1501, self).__init__(train, query, gallery, **kwargs) super(Market1501, self).__init__(train, query, gallery, **kwargs)