mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
Support lazy loading in dataset
Add an option for lazy loading, which parses the data only when it is first accessed. Specifically, accessing `dataset.query` triggers parsing of the query data, while `train` and `gallery` are not parsed in advance. This is useful when you want to run on the test dataset but the train dataset is large. See the example in market1501.py.
This commit is contained in:
parent
d792a69f3f
commit
10b04b75ff
@ -19,9 +19,9 @@ class Dataset(object):
|
||||
This is the base class for ``ImageDataset`` and ``VideoDataset``.
|
||||
|
||||
Args:
|
||||
train (list): contains tuples of (img_path(s), pid, camid).
|
||||
query (list): contains tuples of (img_path(s), pid, camid).
|
||||
gallery (list): contains tuples of (img_path(s), pid, camid).
|
||||
train (list or Callable): contains tuples of (img_path(s), pid, camid).
|
||||
query (list or Callable): contains tuples of (img_path(s), pid, camid).
|
||||
gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
|
||||
transform: transform function.
|
||||
mode (str): 'train', 'query' or 'gallery'.
|
||||
combineall (bool): combines train, query and gallery in a
|
||||
@ -32,17 +32,14 @@ class Dataset(object):
|
||||
|
||||
def __init__(self, train, query, gallery, transform=None, mode='train',
|
||||
combineall=False, verbose=True, **kwargs):
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
self._train = train
|
||||
self._query = query
|
||||
self._gallery = gallery
|
||||
self.transform = transform
|
||||
self.mode = mode
|
||||
self.combineall = combineall
|
||||
self.verbose = verbose
|
||||
|
||||
self.num_train_pids = self.get_num_pids(self.train)
|
||||
self.num_train_cams = self.get_num_cams(self.train)
|
||||
|
||||
if self.combineall:
|
||||
self.combine_all()
|
||||
|
||||
@ -56,6 +53,24 @@ class Dataset(object):
|
||||
raise ValueError('Invalid mode. Got {}, but expected to be '
|
||||
'one of [train | query | gallery]'.format(self.mode))
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
if callable(self._train):
|
||||
self._train = self._train()
|
||||
return self._train
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
if callable(self._query):
|
||||
self._query = self._query()
|
||||
return self._query
|
||||
|
||||
@property
|
||||
def gallery(self):
|
||||
if callable(self._gallery):
|
||||
self._gallery = self._gallery()
|
||||
return self._gallery
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -102,15 +117,14 @@ class Dataset(object):
|
||||
for img_path, pid, camid in data:
|
||||
if pid in self._junk_pids:
|
||||
continue
|
||||
pid = self.dataset_name + "_test_" + str(pid)
|
||||
camid = self.dataset_name + "_test_" + str(camid)
|
||||
pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid)
|
||||
camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid)
|
||||
combined.append((img_path, pid, camid))
|
||||
|
||||
_combine_data(self.query)
|
||||
_combine_data(self.gallery)
|
||||
|
||||
self.train = combined
|
||||
self.num_train_pids = self.get_num_pids(self.train)
|
||||
self._train = combined
|
||||
|
||||
def check_before_run(self, required_files):
|
||||
"""Checks if required files exist before going deeper.
|
||||
@ -134,9 +148,6 @@ class ImageDataset(Dataset):
|
||||
data in each batch has shape (batch_size, channel, height, width).
|
||||
"""
|
||||
|
||||
def __init__(self, train, query, gallery, **kwargs):
|
||||
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
def show_train(self):
|
||||
num_train_pids, num_train_cams = self.parse_data(self.train)
|
||||
|
||||
|
@ -62,11 +62,10 @@ class Market1501(ImageDataset):
|
||||
required_files.append(self.extra_gallery_dir)
|
||||
self.check_before_run(required_files)
|
||||
|
||||
train = self.process_dir(self.train_dir)
|
||||
query = self.process_dir(self.query_dir, is_train=False)
|
||||
gallery = self.process_dir(self.gallery_dir, is_train=False)
|
||||
if self.market1501_500k:
|
||||
gallery += self.process_dir(self.extra_gallery_dir, is_train=False)
|
||||
train = lambda: self.process_dir(self.train_dir)
|
||||
query = lambda: self.process_dir(self.query_dir, is_train=False)
|
||||
gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \
|
||||
(self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else [])
|
||||
|
||||
super(Market1501, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user