add --combineall

pull/133/head
KaiyangZhou 2019-03-15 17:45:47 +00:00
parent 402e054110
commit 983c9e51e5
18 changed files with 104 additions and 46 deletions

View File

@ -23,6 +23,8 @@ def argument_parser():
help='width of an image')
parser.add_argument('--train-sampler', type=str, default='RandomSampler',
help='sampler for trainloader')
parser.add_argument('--combineall', action='store_true',
help='combine all data in a dataset (train+query+gallery) for training')
# ************************************************************
# Data augmentation
@ -191,6 +193,7 @@ def image_dataset_kwargs(parsed_args):
'split_id': parsed_args.split_id,
'height': parsed_args.height,
'width': parsed_args.width,
'combineall': parsed_args.combineall,
'train_batch_size': parsed_args.train_batch_size,
'test_batch_size': parsed_args.test_batch_size,
'workers': parsed_args.workers,
@ -217,6 +220,7 @@ def video_dataset_kwargs(parsed_args):
'split_id': parsed_args.split_id,
'height': parsed_args.height,
'width': parsed_args.width,
'combineall': parsed_args.combineall,
'train_batch_size': parsed_args.train_batch_size,
'test_batch_size': parsed_args.test_batch_size,
'workers': parsed_args.workers,

View File

@ -19,6 +19,7 @@ class BaseDataManager(object):
split_id=0,
height=256,
width=128,
combineall=False, # combine all data in a dataset for training
train_batch_size=32,
test_batch_size=100,
workers=4,
@ -36,6 +37,7 @@ class BaseDataManager(object):
self.split_id = split_id
self.height = height
self.width = width
self.combineall = combineall
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.workers = workers
@ -97,6 +99,7 @@ class ImageDataManager(BaseDataManager):
root=self.root,
name=name,
split_id=self.split_id,
combineall=self.combineall,
cuhk03_labeled=self.cuhk03_labeled,
cuhk03_classic_split=self.cuhk03_classic_split,
market1501_500k=self.market1501_500k
@ -136,6 +139,7 @@ class ImageDataManager(BaseDataManager):
root=self.root,
name=name,
split_id=self.split_id,
combineall=self.combineall,
cuhk03_labeled=self.cuhk03_labeled,
cuhk03_classic_split=self.cuhk03_classic_split,
market1501_500k=self.market1501_500k
@ -196,7 +200,7 @@ class VideoDataManager(BaseDataManager):
self._num_train_cams = 0
for name in self.source_names:
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id, combineall=self.combineall)
for img_paths, pid, camid in dataset.train:
pid += self._num_train_pids
@ -251,7 +255,7 @@ class VideoDataManager(BaseDataManager):
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
for name in self.target_names:
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id, combineall=self.combineall,)
self.testloader_dict[name]['query'] = DataLoader(
VideoDataset(

View File

@ -4,6 +4,7 @@ from __future__ import print_function
import os
import os.path as osp
import numpy as np
import copy
class BaseDataset(object):
@ -32,13 +33,39 @@ class BaseDataset(object):
def get_num_cams(self, data):
return self.extract_data_info(data)[2]
def init_attributes(self, train, query, gallery):
def init_attributes(self, train, query, gallery, combineall=False, **kwargs):
"""Initialize class attributes
Args:
train (list): contains a list of (img_path, pid, camid)
query (list): contains a list of (img_path, pid, camid)
gallery (list): contains a list of (img_path, pid, camid)
combineall (bool): if set to True, combine all data for training, default is False
Notes:
This method has to be called (at the end) in each dataset class.
"""
self._train = train
self._query = query
self._gallery = gallery
self._num_train_pids = self.get_num_pids(train)
self._num_train_cams = self.get_num_cams(train)
if combineall:
self._train = self.combine_all(train, query, gallery)
self._num_train_pids = self.get_num_pids(self.train)
def combine_all(self, train, query, gallery):
"""Combine all data for training
Notes:
1. In general, we assume that all query identities appear in gallery set.
2. All pids in train have been relabeled (starts from 0)
3. pid=0 (background) and pid=-1 (junk) are discarded.
4. Camera views remain the same across train, query and gallery.
"""
raise NotImplementedError
@property
def train(self):
# train list containing (img_path, pid, camid)
@ -72,17 +99,39 @@ class BaseImageDataset(BaseDataset):
"""Base class of image-reid dataset"""
def extract_data_info(self, data):
pids, cams = [], []
pids = set()
cams = set()
for _, pid, camid in data:
pids += [pid]
cams += [camid]
pids = set(pids)
cams = set(cams)
pids.add(pid)
cams.add(camid)
num_pids = len(pids)
num_cams = len(cams)
num_imgs = len(data)
return num_pids, num_imgs, num_cams
def combine_all(self, train, query, gallery):
combined = copy.deepcopy(train)
# relabel pids in gallery
g_pids = set()
for _, pid, _ in gallery:
if pid==0 or pid==-1:
continue
g_pids.add(pid)
pid2label = {pid: i for i, pid in enumerate(g_pids)}
def _combine_data(data):
for img_path, pid, camid in data:
if pid==0 or pid==-1:
continue
pid = pid2label[pid] + self.num_train_pids
combined.append((img_path, pid, camid))
_combine_data(query)
_combine_data(gallery)
return combined
def print_dataset_statistics(self, train, query, gallery):
num_train_pids, num_train_imgs, num_train_cams = self.extract_data_info(train)
num_query_pids, num_query_imgs, num_query_cams = self.extract_data_info(query)
@ -102,13 +151,13 @@ class BaseVideoDataset(BaseDataset):
"""Base class of video-reid dataset"""
def extract_data_info(self, data, return_tracklet_stats=False):
pids, cams, tracklet_stats = [], [], []
pids = set()
cams = set()
tracklet_stats = []
for img_paths, pid, camid in data:
pids += [pid]
cams += [camid]
pids.add(pid)
cams.add(camid)
tracklet_stats += [len(img_paths)]
pids = set(pids)
cams = set(cams)
num_pids = len(pids)
num_cams = len(cams)
num_tracklets = len(data)

View File

@ -63,10 +63,10 @@ class CUHK01(BaseImageDataset):
query = [tuple(item) for item in query]
gallery = [tuple(item) for item in gallery]
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def extract_file(self):
if not osp.exists(self.campus_dir):

View File

@ -81,10 +81,10 @@ class CUHK03(BaseImageDataset):
query = split['query']
gallery = split['gallery']
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def preprocess_split(self):
"""

View File

@ -57,10 +57,10 @@ class DukeMTMCreID(BaseImageDataset):
query = self.process_dir(self.query_dir, relabel=False)
gallery = self.process_dir(self.gallery_dir, relabel=False)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):

View File

@ -60,10 +60,10 @@ class DukeMTMCVidReID(BaseVideoDataset):
query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False)
gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):

View File

@ -67,10 +67,10 @@ class GRID(BaseImageDataset):
query = [tuple(item) for item in query]
gallery = [tuple(item) for item in gallery]
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):

View File

@ -58,10 +58,10 @@ class iLIDS(BaseImageDataset):
train, query, gallery = self.process_split(split)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):

View File

@ -65,10 +65,10 @@ class iLIDSVID(BaseVideoDataset):
query = self.process_data(test_dirs, cam1=True, cam2=False)
gallery = self.process_data(test_dirs, cam1=False, cam2=True)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):

View File

@ -57,10 +57,10 @@ class Market1501(BaseImageDataset):
if self.market1501_500k:
gallery += self.process_dir(self.extra_gallery_dir, relabel=False)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def process_dir(self, dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))

View File

@ -66,10 +66,10 @@ class Mars(BaseVideoDataset):
query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def get_names(self, fpath):
names = []

View File

@ -80,14 +80,15 @@ class MSMT17(BaseImageDataset):
query = self.process_dir(self.test_dir, self.list_query_path)
gallery = self.process_dir(self.test_dir, self.list_gallery_path)
# To fairly compare with published methods, don't use val images for training
#train += val
#num_train_imgs += num_val_imgs
# Note: to fairly compare with published methods on the conventional ReID setting,
# do not add val images to the training set.
if 'combineall' in kwargs and kwargs['combineall']:
train += val
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def process_dir(self, dir_path, list_path):
with open(list_path, 'r') as txt:
lines = txt.readlines()

View File

@ -60,10 +60,10 @@ class PRID(BaseImageDataset):
train, query, gallery = self.process_split(split)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def prepare_split(self):
if not osp.exists(self.split_path):

View File

@ -58,10 +58,10 @@ class PRID2011(BaseVideoDataset):
query = self.process_dir(test_dirs, cam1=True, cam2=False)
gallery = self.process_dir(test_dirs, cam1=False, cam2=True)
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def process_dir(self, dirnames, cam1=True, cam2=True):
tracklets = []

View File

@ -65,10 +65,10 @@ class PRID450S(BaseImageDataset):
query = [tuple(item) for item in query]
gallery = [tuple(item) for item in gallery]
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):

View File

@ -54,10 +54,10 @@ class SenseReID(BaseImageDataset):
gallery = self.process_dir(self.gallery_dir)
train = copy.deepcopy(query) # dummy variable
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def process_dir(self, dir_path):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))

View File

@ -64,10 +64,10 @@ class VIPeR(BaseImageDataset):
query = [tuple(item) for item in query]
gallery = [tuple(item) for item in gallery]
self.init_attributes(train, query, gallery)
self.init_attributes(train, query, gallery, **kwargs)
if verbose:
self.print_dataset_statistics(train, query, gallery)
self.print_dataset_statistics(self.train, self.query, self.gallery)
def download_data(self):
if osp.exists(self.dataset_dir):