Support lazy loading in dataset

Add an option for lazy loading, which will parse the data when using it. Specifically, when calling `dataset.query`, it will start to parse the query data, and `train` and `gallery` will not be parsed in advanced. This is useful when you want to run on test dataset but train dataset is large.

See examples in market1501.py
This commit is contained in:
liaoxingyu 2021-06-17 16:59:02 +08:00
parent d792a69f3f
commit 10b04b75ff
2 changed files with 31 additions and 21 deletions

View File

@ -19,9 +19,9 @@ class Dataset(object):
This is the base class for ``ImageDataset`` and ``VideoDataset``.
Args:
train (list): contains tuples of (img_path(s), pid, camid).
query (list): contains tuples of (img_path(s), pid, camid).
gallery (list): contains tuples of (img_path(s), pid, camid).
train (list or Callable): contains tuples of (img_path(s), pid, camid).
query (list or Callable): contains tuples of (img_path(s), pid, camid).
gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
transform: transform function.
mode (str): 'train', 'query' or 'gallery'.
combineall (bool): combines train, query and gallery in a
@ -32,17 +32,14 @@ class Dataset(object):
def __init__(self, train, query, gallery, transform=None, mode='train',
combineall=False, verbose=True, **kwargs):
self.train = train
self.query = query
self.gallery = gallery
self._train = train
self._query = query
self._gallery = gallery
self.transform = transform
self.mode = mode
self.combineall = combineall
self.verbose = verbose
self.num_train_pids = self.get_num_pids(self.train)
self.num_train_cams = self.get_num_cams(self.train)
if self.combineall:
self.combine_all()
@ -56,6 +53,24 @@ class Dataset(object):
raise ValueError('Invalid mode. Got {}, but expected to be '
'one of [train | query | gallery]'.format(self.mode))
@property
def train(self):
if callable(self._train):
self._train = self._train()
return self._train
@property
def query(self):
if callable(self._query):
self._query = self._query()
return self._query
@property
def gallery(self):
if callable(self._gallery):
self._gallery = self._gallery()
return self._gallery
def __getitem__(self, index):
raise NotImplementedError
@ -102,15 +117,14 @@ class Dataset(object):
for img_path, pid, camid in data:
if pid in self._junk_pids:
continue
pid = self.dataset_name + "_test_" + str(pid)
camid = self.dataset_name + "_test_" + str(camid)
pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid)
camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid)
combined.append((img_path, pid, camid))
_combine_data(self.query)
_combine_data(self.gallery)
self.train = combined
self.num_train_pids = self.get_num_pids(self.train)
self._train = combined
def check_before_run(self, required_files):
"""Checks if required files exist before going deeper.
@ -134,9 +148,6 @@ class ImageDataset(Dataset):
data in each batch has shape (batch_size, channel, height, width).
"""
def __init__(self, train, query, gallery, **kwargs):
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
def show_train(self):
num_train_pids, num_train_cams = self.parse_data(self.train)

View File

@ -62,11 +62,10 @@ class Market1501(ImageDataset):
required_files.append(self.extra_gallery_dir)
self.check_before_run(required_files)
train = self.process_dir(self.train_dir)
query = self.process_dir(self.query_dir, is_train=False)
gallery = self.process_dir(self.gallery_dir, is_train=False)
if self.market1501_500k:
gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
train = lambda: self.process_dir(self.train_dir)
query = lambda: self.process_dir(self.query_dir, is_train=False)
gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \
(self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else [])
super(Market1501, self).__init__(train, query, gallery, **kwargs)