mirror of https://github.com/JDAI-CV/fast-reid.git
Support lazy loading in dataset
Add an option for lazy loading, which will parse the data when using it. Specifically, when calling `dataset.query`, it will start to parse the query data, and `train` and `gallery` will not be parsed in advanced. This is useful when you want to run on test dataset but train dataset is large. See examples in market1501.pypull/525/head
parent
d792a69f3f
commit
10b04b75ff
|
@ -19,9 +19,9 @@ class Dataset(object):
|
|||
This is the base class for ``ImageDataset`` and ``VideoDataset``.
|
||||
|
||||
Args:
|
||||
train (list): contains tuples of (img_path(s), pid, camid).
|
||||
query (list): contains tuples of (img_path(s), pid, camid).
|
||||
gallery (list): contains tuples of (img_path(s), pid, camid).
|
||||
train (list or Callable): contains tuples of (img_path(s), pid, camid).
|
||||
query (list or Callable): contains tuples of (img_path(s), pid, camid).
|
||||
gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
|
||||
transform: transform function.
|
||||
mode (str): 'train', 'query' or 'gallery'.
|
||||
combineall (bool): combines train, query and gallery in a
|
||||
|
@ -32,17 +32,14 @@ class Dataset(object):
|
|||
|
||||
def __init__(self, train, query, gallery, transform=None, mode='train',
|
||||
combineall=False, verbose=True, **kwargs):
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
self._train = train
|
||||
self._query = query
|
||||
self._gallery = gallery
|
||||
self.transform = transform
|
||||
self.mode = mode
|
||||
self.combineall = combineall
|
||||
self.verbose = verbose
|
||||
|
||||
self.num_train_pids = self.get_num_pids(self.train)
|
||||
self.num_train_cams = self.get_num_cams(self.train)
|
||||
|
||||
if self.combineall:
|
||||
self.combine_all()
|
||||
|
||||
|
@ -56,6 +53,24 @@ class Dataset(object):
|
|||
raise ValueError('Invalid mode. Got {}, but expected to be '
|
||||
'one of [train | query | gallery]'.format(self.mode))
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
if callable(self._train):
|
||||
self._train = self._train()
|
||||
return self._train
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
if callable(self._query):
|
||||
self._query = self._query()
|
||||
return self._query
|
||||
|
||||
@property
|
||||
def gallery(self):
|
||||
if callable(self._gallery):
|
||||
self._gallery = self._gallery()
|
||||
return self._gallery
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -102,15 +117,14 @@ class Dataset(object):
|
|||
for img_path, pid, camid in data:
|
||||
if pid in self._junk_pids:
|
||||
continue
|
||||
pid = self.dataset_name + "_test_" + str(pid)
|
||||
camid = self.dataset_name + "_test_" + str(camid)
|
||||
pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid)
|
||||
camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid)
|
||||
combined.append((img_path, pid, camid))
|
||||
|
||||
_combine_data(self.query)
|
||||
_combine_data(self.gallery)
|
||||
|
||||
self.train = combined
|
||||
self.num_train_pids = self.get_num_pids(self.train)
|
||||
self._train = combined
|
||||
|
||||
def check_before_run(self, required_files):
|
||||
"""Checks if required files exist before going deeper.
|
||||
|
@ -134,9 +148,6 @@ class ImageDataset(Dataset):
|
|||
data in each batch has shape (batch_size, channel, height, width).
|
||||
"""
|
||||
|
||||
def __init__(self, train, query, gallery, **kwargs):
|
||||
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
def show_train(self):
|
||||
num_train_pids, num_train_cams = self.parse_data(self.train)
|
||||
|
||||
|
|
|
@ -62,11 +62,10 @@ class Market1501(ImageDataset):
|
|||
required_files.append(self.extra_gallery_dir)
|
||||
self.check_before_run(required_files)
|
||||
|
||||
train = self.process_dir(self.train_dir)
|
||||
query = self.process_dir(self.query_dir, is_train=False)
|
||||
gallery = self.process_dir(self.gallery_dir, is_train=False)
|
||||
if self.market1501_500k:
|
||||
gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
|
||||
train = lambda: self.process_dir(self.train_dir)
|
||||
query = lambda: self.process_dir(self.query_dir, is_train=False)
|
||||
gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \
|
||||
(self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else [])
|
||||
|
||||
super(Market1501, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue