From 10b04b75ff1ff9e3f412e6d1fc54f391f68b33d2 Mon Sep 17 00:00:00 2001 From: liaoxingyu <sherlockliao01@gmail.com> Date: Thu, 17 Jun 2021 16:59:02 +0800 Subject: [PATCH] Support lazy loading in dataset Add an option for lazy loading, which will parse the data when using it. Specifically, when calling `dataset.query`, it will start to parse the query data, and `train` and `gallery` will not be parsed in advanced. This is useful when you want to run on test dataset but train dataset is large. See examples in market1501.py --- fastreid/data/datasets/bases.py | 43 +++++++++++++++++----------- fastreid/data/datasets/market1501.py | 9 +++--- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/fastreid/data/datasets/bases.py b/fastreid/data/datasets/bases.py index 8d731e6..4c68837 100644 --- a/fastreid/data/datasets/bases.py +++ b/fastreid/data/datasets/bases.py @@ -19,9 +19,9 @@ class Dataset(object): This is the base class for ``ImageDataset`` and ``VideoDataset``. Args: - train (list): contains tuples of (img_path(s), pid, camid). - query (list): contains tuples of (img_path(s), pid, camid). - gallery (list): contains tuples of (img_path(s), pid, camid). + train (list or Callable): contains tuples of (img_path(s), pid, camid). + query (list or Callable): contains tuples of (img_path(s), pid, camid). + gallery (list or Callable): contains tuples of (img_path(s), pid, camid). transform: transform function. mode (str): 'train', 'query' or 'gallery'. combineall (bool): combines train, query and gallery in a @@ -32,17 +32,14 @@ class Dataset(object): def __init__(self, train, query, gallery, transform=None, mode='train', combineall=False, verbose=True, **kwargs): - self.train = train - self.query = query - self.gallery = gallery + self._train = train + self._query = query + self._gallery = gallery self.transform = transform self.mode = mode self.combineall = combineall self.verbose = verbose - self.num_train_pids = self.get_num_pids(self.train) - self.num_train_cams = self.get_num_cams(self.train) - if self.combineall: self.combine_all() @@ -56,6 +53,24 @@ class Dataset(object): raise ValueError('Invalid mode. Got {}, but expected to be ' 'one of [train | query | gallery]'.format(self.mode)) + @property + def train(self): + if callable(self._train): + self._train = self._train() + return self._train + + @property + def query(self): + if callable(self._query): + self._query = self._query() + return self._query + + @property + def gallery(self): + if callable(self._gallery): + self._gallery = self._gallery() + return self._gallery + def __getitem__(self, index): raise NotImplementedError @@ -102,15 +117,14 @@ class Dataset(object): for img_path, pid, camid in data: if pid in self._junk_pids: continue - pid = self.dataset_name + "_test_" + str(pid) - camid = self.dataset_name + "_test_" + str(camid) + pid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(pid) + camid = getattr(self, "dataset_name", "Unknown") + "_test_" + str(camid) combined.append((img_path, pid, camid)) _combine_data(self.query) _combine_data(self.gallery) - self.train = combined - self.num_train_pids = self.get_num_pids(self.train) + self._train = combined def check_before_run(self, required_files): """Checks if required files exist before going deeper. @@ -134,9 +148,6 @@ class ImageDataset(Dataset): data in each batch has shape (batch_size, channel, height, width). """ - def __init__(self, train, query, gallery, **kwargs): - super(ImageDataset, self).__init__(train, query, gallery, **kwargs) - def show_train(self): num_train_pids, num_train_cams = self.parse_data(self.train) diff --git a/fastreid/data/datasets/market1501.py b/fastreid/data/datasets/market1501.py index d1968af..1c53cd9 100644 --- a/fastreid/data/datasets/market1501.py +++ b/fastreid/data/datasets/market1501.py @@ -62,11 +62,10 @@ class Market1501(ImageDataset): required_files.append(self.extra_gallery_dir) self.check_before_run(required_files) - train = self.process_dir(self.train_dir) - query = self.process_dir(self.query_dir, is_train=False) - gallery = self.process_dir(self.gallery_dir, is_train=False) - if self.market1501_500k: - gallery += self.process_dir(self.extra_gallery_dir, is_train=False) + train = lambda: self.process_dir(self.train_dir) + query = lambda: self.process_dir(self.query_dir, is_train=False) + gallery = lambda: self.process_dir(self.gallery_dir, is_train=False) + \ + (self.process_dir(self.extra_gallery_dir, is_train=False) if self.market1501_500k else []) super(Market1501, self).__init__(train, query, gallery, **kwargs)