diff --git a/data_manager.py b/data_manager.py
index 67e90ab..b62be86 100755
--- a/data_manager.py
+++ b/data_manager.py
@@ -114,7 +114,7 @@ class CUHK03(object):
     # identities: 1360
     # images: 13164
     # cameras: 6
-    # splits: 20
+    # splits: 20 (classic)
 
     Args:
         split_id (int): split index (default: 0)
@@ -369,13 +369,87 @@ class DukeMTMCreID(object):
     URL: https://github.com/layumi/DukeMTMC-reID_evaluation
 
     Dataset statistics:
-    # identities:
-    # images:
-    # cameras:
-    # splits:
+    # identities: 1404 (train + query)
+    # images: 16522 (train) + 2228 (query) + 17661 (gallery)
+    # cameras: 8
     """
+    root = './data/dukemtmc-reid'
+    train_dir = osp.join(root, 'DukeMTMC-reID/bounding_box_train')
+    query_dir = osp.join(root, 'DukeMTMC-reID/query')
+    gallery_dir = osp.join(root, 'DukeMTMC-reID/bounding_box_test')
+
     def __init__(self):
-        pass
+        """Index the train/query/gallery folders and print a summary table."""
+        self._check_before_run()
+
+        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
+        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
+        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
+        num_total_pids = num_train_pids + num_query_pids
+        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
+
+        print("=> DukeMTMC-reID loaded")
+        print("Dataset statistics:")
+        print(" ------------------------------")
+        print(" subset | # ids | # images")
+        print(" ------------------------------")
+        print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
+        print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
+        print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
+        print(" ------------------------------")
+        print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
+        print(" ------------------------------")
+
+        self.train = train
+        self.query = query
+        self.gallery = gallery
+
+        self.num_train_pids = num_train_pids
+        self.num_query_pids = num_query_pids
+        self.num_gallery_pids = num_gallery_pids
+
+    def _check_before_run(self):
+        """Check if all files are available before going deeper.
+
+        Raises:
+            RuntimeError: if the dataset root or any subset folder is missing.
+        """
+        for path in (self.root, self.train_dir, self.query_dir, self.gallery_dir):
+            if not osp.exists(path):
+                raise RuntimeError("'{}' is not available".format(path))
+
+    def _process_dir(self, dir_path, relabel=False):
+        """Build (img_path, pid, camid) tuples from the *.jpg files in dir_path.
+
+        Args:
+            dir_path (str): folder with images named like '<pid>_c<camid>...jpg'.
+            relabel (bool): map raw person ids to contiguous labels from 0.
+
+        Returns:
+            tuple: (dataset, num_pids, num_imgs).
+        """
+        pattern = re.compile(r'([-\d]+)_c(\d)')
+
+        # Parse each file name once. Matching on the base name keeps digits in
+        # parent directories from being mistaken for the person/camera id.
+        parsed = []
+        for img_path in sorted(glob.glob(osp.join(dir_path, '*.jpg'))):
+            match = pattern.search(osp.basename(img_path))
+            if match is None:
+                raise RuntimeError("'{}' does not match '<pid>_c<camid>'".format(img_path))
+            pid, camid = map(int, match.groups())
+            assert 1 <= camid <= 8
+            parsed.append((img_path, pid, camid - 1))  # camera index starts from 0
+
+        # Sorted ids make the pid -> label mapping reproducible across runs.
+        pids = sorted({pid for _, pid, _ in parsed})
+        pid2label = {pid: label for label, pid in enumerate(pids)}
+
+        dataset = [
+            (img_path, pid2label[pid] if relabel else pid, camid)
+            for img_path, pid, camid in parsed
+        ]
+        return dataset, len(pids), len(dataset)
 
 """Video ReID"""
 
@@ -794,6 +868,7 @@ class PRID(object):
 __factory = {
     'market1501': Market1501,
     'cuhk03': CUHK03,
+    'dukemtmcreid': DukeMTMCreID,
     'mars': Mars,
     'ilidsvid': iLIDSVID,
     'prid': PRID,