Support lazy loading in dataset

Add an option for lazy loading, which defers parsing a data split until it is actually used. Specifically, accessing `dataset.query` triggers parsing of the query data, while `train` and `gallery` are not parsed in advance. This is useful when you only want to run on the test set but the training set is large.

See examples in market1501.py
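
For illustration, a minimal self-contained sketch of the pattern this commit introduces; `LazyDataset` and `parse_split` below are hypothetical stand-ins for the repo's `Dataset` base class and `process_dir`, not part of the actual change:

    # Minimal sketch of the lazy-loading pattern (hypothetical names).
    class LazyDataset:
        def __init__(self, train, query, gallery):
            # Each split may be a ready-made list or a zero-argument callable
            # that returns the list; callables are evaluated on first access.
            self._train = train
            self._query = query
            self._gallery = gallery

        @property
        def query(self):
            if callable(self._query):
                self._query = self._query()  # parse once, then cache the result
            return self._query

        # train/gallery properties would follow the same pattern.


    def parse_split(name):
        # Stand-in for an expensive directory scan such as process_dir().
        print("parsing", name)
        return [(name + "_0001.jpg", 0, 0)]


    dataset = LazyDataset(
        train=lambda: parse_split("train"),      # never evaluated below
        query=lambda: parse_split("query"),
        gallery=lambda: parse_split("gallery"),
    )
    print(dataset.query)  # only now is the query split parsed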
pull/525/head
liaoxingyu 2021-06-17 16:59:02 +08:00
parent d792a69f3f
commit 10b04b75ff
2 changed files with 31 additions and 21 deletions


@@ -19,9 +19,9 @@ class Dataset(object):
     This is the base class for ``ImageDataset`` and ``VideoDataset``.
 
     Args:
-        train (list): contains tuples of (img_path(s), pid, camid).
-        query (list): contains tuples of (img_path(s), pid, camid).
-        gallery (list): contains tuples of (img_path(s), pid, camid).
+        train (list or Callable): contains tuples of (img_path(s), pid, camid).
+        query (list or Callable): contains tuples of (img_path(s), pid, camid).
+        gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
         transform: transform function.
         mode (str): 'train', 'query' or 'gallery'.
         combineall (bool): combines train, query and gallery in a
@@ -32,17 +32,14 @@ class Dataset(object):
     def __init__(self, train, query, gallery, transform=None, mode='train',
                  combineall=False, verbose=True, **kwargs):
-        self.train = train
-        self.query = query
-        self.gallery = gallery
+        self._train = train
+        self._query = query
+        self._gallery = gallery
         self.transform = transform
         self.mode = mode
         self.combineall = combineall
         self.verbose = verbose
 
-        self.num_train_pids = self.get_num_pids(self.train)
-        self.num_train_cams = self.get_num_cams(self.train)
         if self.combineall:
             self.combine_all()
@@ -56,6 +53,24 @@ class Dataset(object):
             raise ValueError('Invalid mode. Got {}, but expected to be '
                              'one of [train | query | gallery]'.format(self.mode))
 
+    @property
+    def train(self):
+        if callable(self._train):
+            self._train = self._train()
+        return self._train
+
+    @property
+    def query(self):
+        if callable(self._query):
+            self._query = self._query()
+        return self._query
+
+    @property
+    def gallery(self):
+        if callable(self._gallery):
+            self._gallery = self._gallery()
+        return self._gallery
+
     def __getitem__(self, index):
         raise NotImplementedError
@@ -102,15 +117,14 @@ class Dataset(object):
             for img_path, pid, camid in data:
                 if pid in self._junk_pids:
                     continue
-                pid = self.dataset_name + "_test_" + str(pid)
-                camid = self.dataset_name + "_test_" + str(camid)
+                pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid)
+                camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid)
                 combined.append((img_path, pid, camid))
 
         _combine_data(self.query)
         _combine_data(self.gallery)
 
-        self.train = combined
-        self.num_train_pids = self.get_num_pids(self.train)
+        self._train = combined
 
     def check_before_run(self, required_files):
         """Checks if required files exist before going deeper.
@@ -134,9 +148,6 @@ class ImageDataset(Dataset):
     data in each batch has shape (batch_size, channel, height, width).
     """
 
-    def __init__(self, train, query, gallery, **kwargs):
-        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
-
     def show_train(self):
         num_train_pids, num_train_cams = self.parse_data(self.train)


@@ -62,11 +62,10 @@ class Market1501(ImageDataset):
             required_files.append(self.extra_gallery_dir)
         self.check_before_run(required_files)
 
-        train = self.process_dir(self.train_dir)
-        query = self.process_dir(self.query_dir, is_train=False)
-        gallery = self.process_dir(self.gallery_dir, is_train=False)
-        if self.market1501_500k:
-            gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
+        train = lambda: self.process_dir(self.train_dir)
+        query = lambda: self.process_dir(self.query_dir, is_train=False)
+        gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \
+                          (self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else [])
 
         super(Market1501, self).__init__(train, query, gallery, **kwargs)