add RandomDatasetSampler

pull/405/head
KaiyangZhou 2020-08-10 12:31:11 +01:00
parent 52b334a63e
commit 4c7c43d34e
3 changed files with 97 additions and 13 deletions

View File

@@ -43,6 +43,7 @@ def get_default_config():
cfg.sampler.train_sampler_t = 'RandomSampler' # sampler for target train loader
cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler
cfg.sampler.num_cams = 1 # number of cameras to sample in a batch (for RandomDomainSampler)
+cfg.sampler.num_datasets = 1 # number of datasets to sample in a batch (for RandomDatasetSampler)
# video reid setting
cfg.video = CN()
@@ -128,9 +129,10 @@ def imagedata_kwargs(cfg):
'workers': cfg.data.workers,
'num_instances': cfg.sampler.num_instances,
'num_cams': cfg.sampler.num_cams,
+'num_datasets': cfg.sampler.num_datasets,
'train_sampler': cfg.sampler.train_sampler,
'train_sampler_t': cfg.sampler.train_sampler_t,
-# image
+# image dataset specific
'cuhk03_labeled': cfg.cuhk03.labeled_images,
'cuhk03_classic_split': cfg.cuhk03.classic_split,
'market1501_500k': cfg.market1501.use_500k_distractors,
@@ -155,8 +157,9 @@ def videodata_kwargs(cfg):
'workers': cfg.data.workers,
'num_instances': cfg.sampler.num_instances,
'num_cams': cfg.sampler.num_cams,
+'num_datasets': cfg.sampler.num_datasets,
'train_sampler': cfg.sampler.train_sampler,
-# video
+# video dataset specific
'seq_len': cfg.video.seq_len,
'sample_method': cfg.video.sample_method
}
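Taken together, the first file's changes expose the new sampler through the config: set cfg.sampler.train_sampler (or cfg.sampler.train_sampler_t for the target loader) to 'RandomDatasetSampler', pick cfg.sampler.num_datasets, and imagedata_kwargs()/videodata_kwargs() forward the value to the data manager. Below is a minimal sketch of that flow, assuming it runs in the same module as get_default_config() and that the training batch size lives under cfg.train.batch_size (a field outside this diff):

cfg = get_default_config()
cfg.sampler.train_sampler = 'RandomDatasetSampler'  # use the new sampler for the source train loader
cfg.sampler.num_datasets = 2                        # mix images from 2 datasets in every batch
cfg.train.batch_size = 64                           # assumed field name; must be divisible by num_datasets

# The kwargs dict now carries 'num_datasets': 2, which a training script can
# splat into ImageDataManager(**kwargs) / VideoDataManager(**kwargs).
kwargs = imagedata_kwargs(cfg)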

View File

@@ -117,6 +117,8 @@ class ImageDataManager(DataManager):
Default is 4.
num_cams (int, optional): number of cameras to sample in a batch (when using
``RandomDomainSampler``). Default is 1.
+num_datasets (int, optional): number of datasets to sample in a batch (when
+using ``RandomDatasetSampler``). Default is 1.
train_sampler (str, optional): sampler. Default is RandomSampler.
train_sampler_t (str, optional): sampler for target train loader. Default is RandomSampler.
cuhk03_labeled (bool, optional): use cuhk03 labeled images.
@@ -168,6 +170,7 @@ class ImageDataManager(DataManager):
workers=4,
num_instances=4,
num_cams=1,
+num_datasets=1,
train_sampler='RandomSampler',
train_sampler_t='RandomSampler',
cuhk03_labeled=False,
@@ -214,7 +217,8 @@ class ImageDataManager(DataManager):
train_sampler,
batch_size=batch_size_train,
num_instances=num_instances,
-num_cams=num_cams
+num_cams=num_cams,
+num_datasets=num_datasets
),
batch_size=batch_size_train,
shuffle=False,
@@ -254,7 +258,8 @@ class ImageDataManager(DataManager):
train_sampler_t,
batch_size=batch_size_train,
num_instances=num_instances,
-num_cams=num_cams
+num_cams=num_cams,
+num_datasets=num_datasets
),
batch_size=batch_size_train,
shuffle=False,
@@ -367,6 +372,8 @@ class VideoDataManager(DataManager):
Default is 4.
num_cams (int, optional): number of cameras to sample in a batch (when using
``RandomDomainSampler``). Default is 1.
+num_datasets (int, optional): number of datasets to sample in a batch (when
+using ``RandomDatasetSampler``). Default is 1.
train_sampler (str, optional): sampler. Default is RandomSampler.
seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
@@ -419,6 +426,7 @@ class VideoDataManager(DataManager):
workers=4,
num_instances=4,
num_cams=1,
+num_datasets=1,
train_sampler='RandomSampler',
seq_len=15,
sample_method='evenly'
@@ -459,7 +467,8 @@ class VideoDataManager(DataManager):
train_sampler,
batch_size=batch_size_train,
num_instances=num_instances,
-num_cams=num_cams
+num_cams=num_cams,
+num_datasets=num_datasets
)
self.train_loader = torch.utils.data.DataLoader(
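The same knob can also be passed directly when constructing a data manager. A sketch under stated assumptions: torchreid.data.ImageDataManager, root, sources, batch_size_train and the dataset keys come from the library's public API rather than from this diff, while train_sampler and num_datasets are the parameters touched here:

import torchreid

datamanager = torchreid.data.ImageDataManager(
    root='reid-data',
    sources=['market1501', 'dukemtmcreid'],  # two source datasets -> two distinct dsetid values
    batch_size_train=64,                     # must be divisible by num_datasets
    train_sampler='RandomDatasetSampler',
    num_datasets=2                           # every training batch draws 32 images from each dataset
)
train_loader = datamanager.train_loader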

View File

@@ -7,7 +7,7 @@ from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
AVAI_SAMPLERS = [
'RandomIdentitySampler', 'SequentialSampler', 'RandomSampler',
-'RandomDomainSampler'
+'RandomDomainSampler', 'RandomDatasetSampler'
]
@@ -15,7 +15,7 @@ class RandomIdentitySampler(Sampler):
"""Randomly samples N identities each with K instances.
Args:
-data_source (list): contains tuples of (img_path(s), pid, camid).
+data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
batch_size (int): batch size.
num_instances (int): number of instances per identity in a batch.
"""
@@ -32,7 +32,8 @@ class RandomIdentitySampler(Sampler):
self.num_instances = num_instances
self.num_pids_per_batch = self.batch_size // self.num_instances
self.index_dic = defaultdict(list)
-for index, (_, pid, _) in enumerate(self.data_source):
+for index, items in enumerate(data_source):
+pid = items[1]
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())
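Since the rewritten loop reads only items[1] (the pid), RandomIdentitySampler keeps working whether data_source holds the old 3-tuples or the new 4-tuples carrying dsetid. A small illustrative sketch, assuming RandomIdentitySampler is in scope and with made-up file names and labels:

old_style = [('a.jpg', 0, 0), ('b.jpg', 0, 1), ('c.jpg', 1, 0), ('d.jpg', 1, 1)]
new_style = [('a.jpg', 0, 0, 0), ('b.jpg', 0, 1, 0), ('c.jpg', 1, 0, 1), ('d.jpg', 1, 1, 1)]

# Either list is accepted; 2 identities x 2 instances give one batch of 4 indices.
sampler = RandomIdentitySampler(new_style, batch_size=4, num_instances=2)
batch = list(iter(sampler))  # e.g. [0, 1, 2, 3] in some order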
@@ -83,10 +84,16 @@
class RandomDomainSampler(Sampler):
"""Random domain sampler.
-Each camera is considered as a distinct domain.
+We consider each camera as a visual domain.
-1. Randomly sample N cameras.
+How does the sampling work:
+1. Randomly sample N cameras (based on the "camid" label).
2. From each camera, randomly sample K images.
+Args:
+data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
+batch_size (int): batch size.
+n_domain (int): number of cameras to sample in a batch.
"""
def __init__(self, data_source, batch_size, n_domain):
@@ -94,18 +101,18 @@
# Keep track of image indices for each domain
self.domain_dict = defaultdict(list)
-for i, (_, _, camid) in enumerate(data_source):
+for i, items in enumerate(data_source):
+camid = items[2]
self.domain_dict[camid].append(i)
self.domains = list(self.domain_dict.keys())
-# Make sure each domain has equal number of images
+# Make sure each domain can be assigned an equal number of images
if n_domain is None or n_domain <= 0:
n_domain = len(self.domains)
assert batch_size % n_domain == 0
self.n_img_per_domain = batch_size // n_domain
self.batch_size = batch_size
-# n_domain denotes number of domains sampled in a minibatch
self.n_domain = n_domain
self.length = len(list(self.__iter__()))
@@ -135,12 +142,72 @@
return self.length
+class RandomDatasetSampler(Sampler):
+"""Random dataset sampler.
+How does the sampling work:
+1. Randomly sample N datasets (based on the "dsetid" label).
+2. From each dataset, randomly sample K images.
+Args:
+data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
+batch_size (int): batch size.
+n_dataset (int): number of datasets to sample in a batch.
+"""
+def __init__(self, data_source, batch_size, n_dataset):
+self.data_source = data_source
+# Keep track of image indices for each dataset
+self.dataset_dict = defaultdict(list)
+for i, items in enumerate(data_source):
+dsetid = items[3]
+self.dataset_dict[dsetid].append(i)
+self.datasets = list(self.dataset_dict.keys())
+# Make sure each dataset can be assigned an equal number of images
+if n_dataset is None or n_dataset <= 0:
+n_dataset = len(self.datasets)
+assert batch_size % n_dataset == 0
+self.n_img_per_dset = batch_size // n_dataset
+self.batch_size = batch_size
+self.n_dataset = n_dataset
+self.length = len(list(self.__iter__()))
+def __iter__(self):
+dataset_dict = copy.deepcopy(self.dataset_dict)
+final_idxs = []
+stop_sampling = False
+while not stop_sampling:
+selected_datasets = random.sample(self.datasets, self.n_dataset)
+for dset in selected_datasets:
+idxs = dataset_dict[dset]
+selected_idxs = random.sample(idxs, self.n_img_per_dset)
+final_idxs.extend(selected_idxs)
+for idx in selected_idxs:
+dataset_dict[dset].remove(idx)
+remaining = len(dataset_dict[dset])
+if remaining < self.n_img_per_dset:
+stop_sampling = True
+return iter(final_idxs)
+def __len__(self):
+return self.length
def build_train_sampler(
data_source,
train_sampler,
batch_size=32,
num_instances=4,
num_cams=1,
+num_datasets=1,
**kwargs
):
"""Builds a training sampler.
@@ -153,6 +220,8 @@ def build_train_sampler(
batch (when using ``RandomIdentitySampler``). Default is 4.
num_cams (int, optional): number of cameras to sample in a batch (when using
``RandomDomainSampler``). Default is 1.
+num_datasets (int, optional): number of datasets to sample in a batch (when
+using ``RandomDatasetSampler``). Default is 1.
"""
assert train_sampler in AVAI_SAMPLERS, \
'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)
@@ -163,6 +232,9 @@ def build_train_sampler(
elif train_sampler == 'RandomDomainSampler':
sampler = RandomDomainSampler(data_source, batch_size, num_cams)
+elif train_sampler == 'RandomDatasetSampler':
+sampler = RandomDatasetSampler(data_source, batch_size, num_datasets)
elif train_sampler == 'SequentialSampler':
sampler = SequentialSampler(data_source)
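
To close the loop, a self-contained sketch of what the new branch produces on a toy data_source, assuming build_train_sampler is in scope; the tuples follow the (img_path(s), pid, camid, dsetid) format documented above:

# Toy data_source: 8 images from dsetid 0 and 8 images from dsetid 1.
data_source = [('img_%d.jpg' % i, i % 4, 0, i // 8) for i in range(16)]

sampler = build_train_sampler(
    data_source, 'RandomDatasetSampler', batch_size=8, num_datasets=2
)
# 16 indices in total; every consecutive block of 8 mixes 4 images from each dataset.
print(list(iter(sampler)))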