add --combineall
parent
402e054110
commit
983c9e51e5
4
args.py
4
args.py
|
@ -23,6 +23,8 @@ def argument_parser():
|
|||
help='width of an image')
|
||||
parser.add_argument('--train-sampler', type=str, default='RandomSampler',
|
||||
help='sampler for trainloader')
|
||||
parser.add_argument('--combineall', action='store_true',
|
||||
help='combine all data in a dataset (train+query+gallery) for training')
|
||||
|
||||
# ************************************************************
|
||||
# Data augmentation
|
||||
|
@ -191,6 +193,7 @@ def image_dataset_kwargs(parsed_args):
|
|||
'split_id': parsed_args.split_id,
|
||||
'height': parsed_args.height,
|
||||
'width': parsed_args.width,
|
||||
'combineall': parsed_args.combineall,
|
||||
'train_batch_size': parsed_args.train_batch_size,
|
||||
'test_batch_size': parsed_args.test_batch_size,
|
||||
'workers': parsed_args.workers,
|
||||
|
@ -217,6 +220,7 @@ def video_dataset_kwargs(parsed_args):
|
|||
'split_id': parsed_args.split_id,
|
||||
'height': parsed_args.height,
|
||||
'width': parsed_args.width,
|
||||
'combineall': parsed_args.combineall,
|
||||
'train_batch_size': parsed_args.train_batch_size,
|
||||
'test_batch_size': parsed_args.test_batch_size,
|
||||
'workers': parsed_args.workers,
|
||||
|
|
|
@ -19,6 +19,7 @@ class BaseDataManager(object):
|
|||
split_id=0,
|
||||
height=256,
|
||||
width=128,
|
||||
combineall=False, # combine all data in a dataset for training
|
||||
train_batch_size=32,
|
||||
test_batch_size=100,
|
||||
workers=4,
|
||||
|
@ -36,6 +37,7 @@ class BaseDataManager(object):
|
|||
self.split_id = split_id
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.combineall = combineall
|
||||
self.train_batch_size = train_batch_size
|
||||
self.test_batch_size = test_batch_size
|
||||
self.workers = workers
|
||||
|
@ -97,6 +99,7 @@ class ImageDataManager(BaseDataManager):
|
|||
root=self.root,
|
||||
name=name,
|
||||
split_id=self.split_id,
|
||||
combineall=self.combineall,
|
||||
cuhk03_labeled=self.cuhk03_labeled,
|
||||
cuhk03_classic_split=self.cuhk03_classic_split,
|
||||
market1501_500k=self.market1501_500k
|
||||
|
@ -136,6 +139,7 @@ class ImageDataManager(BaseDataManager):
|
|||
root=self.root,
|
||||
name=name,
|
||||
split_id=self.split_id,
|
||||
combineall=self.combineall,
|
||||
cuhk03_labeled=self.cuhk03_labeled,
|
||||
cuhk03_classic_split=self.cuhk03_classic_split,
|
||||
market1501_500k=self.market1501_500k
|
||||
|
@ -196,7 +200,7 @@ class VideoDataManager(BaseDataManager):
|
|||
self._num_train_cams = 0
|
||||
|
||||
for name in self.source_names:
|
||||
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
|
||||
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id, combineall=self.combineall)
|
||||
|
||||
for img_paths, pid, camid in dataset.train:
|
||||
pid += self._num_train_pids
|
||||
|
@ -251,7 +255,7 @@ class VideoDataManager(BaseDataManager):
|
|||
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||
|
||||
for name in self.target_names:
|
||||
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
|
||||
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id, combineall=self.combineall,)
|
||||
|
||||
self.testloader_dict[name]['query'] = DataLoader(
|
||||
VideoDataset(
|
||||
|
|
|
@ -4,6 +4,7 @@ from __future__ import print_function
|
|||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
|
||||
class BaseDataset(object):
|
||||
|
@ -32,13 +33,39 @@ class BaseDataset(object):
|
|||
def get_num_cams(self, data):
|
||||
return self.extract_data_info(data)[2]
|
||||
|
||||
def init_attributes(self, train, query, gallery):
|
||||
def init_attributes(self, train, query, gallery, combineall=False, **kwargs):
|
||||
"""Initialize class attributes
|
||||
|
||||
Args:
|
||||
train (list): contains a list of (img_path, pid, camid)
|
||||
query (list): contains a list of (img_path, pid, camid)
|
||||
gallery (list): contains a list of (img_path, pid, camid)
|
||||
combineall (bool): if set to True, combine all data for training, default is False
|
||||
|
||||
Notes:
|
||||
This method has to be called (at the end) in each dataset class.
|
||||
"""
|
||||
self._train = train
|
||||
self._query = query
|
||||
self._gallery = gallery
|
||||
self._num_train_pids = self.get_num_pids(train)
|
||||
self._num_train_cams = self.get_num_cams(train)
|
||||
|
||||
if combineall:
|
||||
self._train = self.combine_all(train, query, gallery)
|
||||
self._num_train_pids = self.get_num_pids(self.train)
|
||||
|
||||
def combine_all(self, train, query, gallery):
    """Combine all data for training (hook to be overridden by subclasses).

    Args:
        train (list): training samples.
        query (list): query samples.
        gallery (list): gallery samples.

    Raises:
        NotImplementedError: always. This base version performs no
            combination; the message names the concrete class so a missing
            override is easy to spot.

    Notes:
        1. In general, we assume that all query identities appear in the gallery set.
        2. All pids in train have been relabeled (starting from 0).
        3. pid=0 (background) and pid=-1 (junk) are discarded.
        4. Camera views remain the same across train, query and gallery.
    """
    # Keep the exception type unchanged for callers; only add a message that
    # tells the user which class lacks the override.
    raise NotImplementedError(
        '{}.combine_all() is not implemented; it must be overridden '
        'to use combineall=True'.format(type(self).__name__)
    )
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
# train list containing (img_path, pid, camid)
|
||||
|
@ -72,17 +99,39 @@ class BaseImageDataset(BaseDataset):
|
|||
"""Base class of image-reid dataset"""
|
||||
|
||||
def extract_data_info(self, data):
|
||||
pids, cams = [], []
|
||||
pids = set()
|
||||
cams = set()
|
||||
for _, pid, camid in data:
|
||||
pids += [pid]
|
||||
cams += [camid]
|
||||
pids = set(pids)
|
||||
cams = set(cams)
|
||||
pids.add(pid)
|
||||
cams.add(camid)
|
||||
num_pids = len(pids)
|
||||
num_cams = len(cams)
|
||||
num_imgs = len(data)
|
||||
return num_pids, num_imgs, num_cams
|
||||
|
||||
def combine_all(self, train, query, gallery):
    """Merge query and gallery samples into a copy of the training list.

    Args:
        train (list): (img_path, pid, camid) tuples. Deep-copied, never
            mutated.
        query (list): (img_path, pid, camid) tuples.
        gallery (list): (img_path, pid, camid) tuples.

    Returns:
        list: the copied ``train`` entries, followed by the kept query
        entries and then the kept gallery entries. Every appended pid is
        replaced by ``label + self.num_train_pids``, where ``label`` is the
        0-based rank of the original pid among all kept query/gallery pids.

    Notes:
        - Entries with pid 0 (background) or pid -1 (junk) are skipped.
        - The label table is built from the pids of BOTH query and gallery,
          so a query identity that is missing from the gallery gets its own
          label instead of raising ``KeyError``.
        - pids are sorted before numbering so the mapping is reproducible.
    """
    ignored_pids = (0, -1)
    combined = copy.deepcopy(train)

    # Collect every identity that will actually be appended.
    test_pids = {
        pid
        for subset in (query, gallery)
        for _, pid, _ in subset
        if pid not in ignored_pids
    }
    pid2label = {pid: label for label, pid in enumerate(sorted(test_pids))}

    # New labels start right after the existing training labels.
    offset = self.num_train_pids
    for subset in (query, gallery):
        for img_path, pid, camid in subset:
            if pid in ignored_pids:
                continue
            combined.append((img_path, pid2label[pid] + offset, camid))

    return combined
|
||||
|
||||
def print_dataset_statistics(self, train, query, gallery):
|
||||
num_train_pids, num_train_imgs, num_train_cams = self.extract_data_info(train)
|
||||
num_query_pids, num_query_imgs, num_query_cams = self.extract_data_info(query)
|
||||
|
@ -102,13 +151,13 @@ class BaseVideoDataset(BaseDataset):
|
|||
"""Base class of video-reid dataset"""
|
||||
|
||||
def extract_data_info(self, data, return_tracklet_stats=False):
|
||||
pids, cams, tracklet_stats = [], [], []
|
||||
pids = set()
|
||||
cams = set()
|
||||
tracklet_stats = []
|
||||
for img_paths, pid, camid in data:
|
||||
pids += [pid]
|
||||
cams += [camid]
|
||||
pids.add(pid)
|
||||
cams.add(camid)
|
||||
tracklet_stats += [len(img_paths)]
|
||||
pids = set(pids)
|
||||
cams = set(cams)
|
||||
num_pids = len(pids)
|
||||
num_cams = len(cams)
|
||||
num_tracklets = len(data)
|
||||
|
|
|
@ -63,10 +63,10 @@ class CUHK01(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def extract_file(self):
|
||||
if not osp.exists(self.campus_dir):
|
||||
|
|
|
@ -81,10 +81,10 @@ class CUHK03(BaseImageDataset):
|
|||
query = split['query']
|
||||
gallery = split['gallery']
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def preprocess_split(self):
|
||||
"""
|
||||
|
|
|
@ -57,10 +57,10 @@ class DukeMTMCreID(BaseImageDataset):
|
|||
query = self.process_dir(self.query_dir, relabel=False)
|
||||
gallery = self.process_dir(self.gallery_dir, relabel=False)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
|
@ -60,10 +60,10 @@ class DukeMTMCVidReID(BaseVideoDataset):
|
|||
query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False)
|
||||
gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
|
@ -67,10 +67,10 @@ class GRID(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
|
@ -58,10 +58,10 @@ class iLIDS(BaseImageDataset):
|
|||
|
||||
train, query, gallery = self.process_split(split)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
|
@ -65,10 +65,10 @@ class iLIDSVID(BaseVideoDataset):
|
|||
query = self.process_data(test_dirs, cam1=True, cam2=False)
|
||||
gallery = self.process_data(test_dirs, cam1=False, cam2=True)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
|
@ -57,10 +57,10 @@ class Market1501(BaseImageDataset):
|
|||
if self.market1501_500k:
|
||||
gallery += self.process_dir(self.extra_gallery_dir, relabel=False)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def process_dir(self, dir_path, relabel=False):
|
||||
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
||||
|
|
|
@ -66,10 +66,10 @@ class Mars(BaseVideoDataset):
|
|||
query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
||||
gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def get_names(self, fpath):
|
||||
names = []
|
||||
|
|
|
@ -80,14 +80,15 @@ class MSMT17(BaseImageDataset):
|
|||
query = self.process_dir(self.test_dir, self.list_query_path)
|
||||
gallery = self.process_dir(self.test_dir, self.list_gallery_path)
|
||||
|
||||
# To fairly compare with published methods, don't use val images for training
|
||||
#train += val
|
||||
#num_train_imgs += num_val_imgs
|
||||
# Note: to fairly compare with published methods on the conventional ReID setting,
|
||||
# do not add val images to the training set.
|
||||
if 'combineall' in kwargs and kwargs['combineall']:
|
||||
train += val
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
def process_dir(self, dir_path, list_path):
|
||||
with open(list_path, 'r') as txt:
|
||||
lines = txt.readlines()
|
||||
|
|
|
@ -60,10 +60,10 @@ class PRID(BaseImageDataset):
|
|||
|
||||
train, query, gallery = self.process_split(split)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def prepare_split(self):
|
||||
if not osp.exists(self.split_path):
|
||||
|
|
|
@ -58,10 +58,10 @@ class PRID2011(BaseVideoDataset):
|
|||
query = self.process_dir(test_dirs, cam1=True, cam2=False)
|
||||
gallery = self.process_dir(test_dirs, cam1=False, cam2=True)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def process_dir(self, dirnames, cam1=True, cam2=True):
|
||||
tracklets = []
|
||||
|
|
|
@ -65,10 +65,10 @@ class PRID450S(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
|
@ -54,10 +54,10 @@ class SenseReID(BaseImageDataset):
|
|||
gallery = self.process_dir(self.gallery_dir)
|
||||
train = copy.deepcopy(query) # dummy variable
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def process_dir(self, dir_path):
|
||||
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
||||
|
|
|
@ -64,10 +64,10 @@ class VIPeR(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
self.init_attributes(train, query, gallery, **kwargs)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
self.print_dataset_statistics(self.train, self.query, self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
|
|
Loading…
Reference in New Issue