add RandomDatasetSampler
parent
52b334a63e
commit
4c7c43d34e
|
@ -43,6 +43,7 @@ def get_default_config():
|
|||
cfg.sampler.train_sampler_t = 'RandomSampler' # sampler for target train loader
|
||||
cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler
|
||||
cfg.sampler.num_cams = 1 # number of cameras to sample in a batch (for RandomDomainSampler)
|
||||
cfg.sampler.num_datasets = 1 # number of datasets to sample in a batch (for RandomDatasetSampler)
|
||||
|
||||
# video reid setting
|
||||
cfg.video = CN()
|
||||
|
@ -128,9 +129,10 @@ def imagedata_kwargs(cfg):
|
|||
'workers': cfg.data.workers,
|
||||
'num_instances': cfg.sampler.num_instances,
|
||||
'num_cams': cfg.sampler.num_cams,
|
||||
'num_datasets': cfg.sampler.num_datasets,
|
||||
'train_sampler': cfg.sampler.train_sampler,
|
||||
'train_sampler_t': cfg.sampler.train_sampler_t,
|
||||
# image
|
||||
# image dataset specific
|
||||
'cuhk03_labeled': cfg.cuhk03.labeled_images,
|
||||
'cuhk03_classic_split': cfg.cuhk03.classic_split,
|
||||
'market1501_500k': cfg.market1501.use_500k_distractors,
|
||||
|
@ -155,8 +157,9 @@ def videodata_kwargs(cfg):
|
|||
'workers': cfg.data.workers,
|
||||
'num_instances': cfg.sampler.num_instances,
|
||||
'num_cams': cfg.sampler.num_cams,
|
||||
'num_datasets': cfg.sampler.num_datasets,
|
||||
'train_sampler': cfg.sampler.train_sampler,
|
||||
# video
|
||||
# video dataset specific
|
||||
'seq_len': cfg.video.seq_len,
|
||||
'sample_method': cfg.video.sample_method
|
||||
}
|
||||
|
|
|
@ -117,6 +117,8 @@ class ImageDataManager(DataManager):
|
|||
Default is 4.
|
||||
num_cams (int, optional): number of cameras to sample in a batch (when using
|
||||
``RandomDomainSampler``). Default is 1.
|
||||
num_datasets (int, optional): number of datasets to sample in a batch (when
|
||||
using ``RandomDatasetSampler``). Default is 1.
|
||||
train_sampler (str, optional): sampler. Default is RandomSampler.
|
||||
train_sampler_t (str, optional): sampler for target train loader. Default is RandomSampler.
|
||||
cuhk03_labeled (bool, optional): use cuhk03 labeled images.
|
||||
|
@ -168,6 +170,7 @@ class ImageDataManager(DataManager):
|
|||
workers=4,
|
||||
num_instances=4,
|
||||
num_cams=1,
|
||||
num_datasets=1,
|
||||
train_sampler='RandomSampler',
|
||||
train_sampler_t='RandomSampler',
|
||||
cuhk03_labeled=False,
|
||||
|
@ -214,7 +217,8 @@ class ImageDataManager(DataManager):
|
|||
train_sampler,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances,
|
||||
num_cams=num_cams
|
||||
num_cams=num_cams,
|
||||
num_datasets=num_datasets
|
||||
),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=False,
|
||||
|
@ -254,7 +258,8 @@ class ImageDataManager(DataManager):
|
|||
train_sampler_t,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances,
|
||||
num_cams=num_cams
|
||||
num_cams=num_cams,
|
||||
num_datasets=num_datasets
|
||||
),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=False,
|
||||
|
@ -367,6 +372,8 @@ class VideoDataManager(DataManager):
|
|||
Default is 4.
|
||||
num_cams (int, optional): number of cameras to sample in a batch (when using
|
||||
``RandomDomainSampler``). Default is 1.
|
||||
num_datasets (int, optional): number of datasets to sample in a batch (when
|
||||
using ``RandomDatasetSampler``). Default is 1.
|
||||
train_sampler (str, optional): sampler. Default is RandomSampler.
|
||||
seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
|
||||
sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
|
||||
|
@ -419,6 +426,7 @@ class VideoDataManager(DataManager):
|
|||
workers=4,
|
||||
num_instances=4,
|
||||
num_cams=1,
|
||||
num_datasets=1,
|
||||
train_sampler='RandomSampler',
|
||||
seq_len=15,
|
||||
sample_method='evenly'
|
||||
|
@ -459,7 +467,8 @@ class VideoDataManager(DataManager):
|
|||
train_sampler,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances,
|
||||
num_cams=num_cams
|
||||
num_cams=num_cams,
|
||||
num_datasets=num_datasets
|
||||
)
|
||||
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
|
|||
|
||||
AVAI_SAMPLERS = [
|
||||
'RandomIdentitySampler', 'SequentialSampler', 'RandomSampler',
|
||||
'RandomDomainSampler'
|
||||
'RandomDomainSampler', 'RandomDatasetSampler'
|
||||
]
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ class RandomIdentitySampler(Sampler):
|
|||
"""Randomly samples N identities each with K instances.
|
||||
|
||||
Args:
|
||||
data_source (list): contains tuples of (img_path(s), pid, camid).
|
||||
data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
|
||||
batch_size (int): batch size.
|
||||
num_instances (int): number of instances per identity in a batch.
|
||||
"""
|
||||
|
@ -32,7 +32,8 @@ class RandomIdentitySampler(Sampler):
|
|||
self.num_instances = num_instances
|
||||
self.num_pids_per_batch = self.batch_size // self.num_instances
|
||||
self.index_dic = defaultdict(list)
|
||||
for index, (_, pid, _) in enumerate(self.data_source):
|
||||
for index, items in enumerate(data_source):
|
||||
pid = items[1]
|
||||
self.index_dic[pid].append(index)
|
||||
self.pids = list(self.index_dic.keys())
|
||||
|
||||
|
@ -83,10 +84,16 @@ class RandomIdentitySampler(Sampler):
|
|||
class RandomDomainSampler(Sampler):
|
||||
"""Random domain sampler.
|
||||
|
||||
Each camera is considered as a distinct domain.
|
||||
We consider each camera as a visual domain.
|
||||
|
||||
1. Randomly sample N cameras.
|
||||
How does the sampling work:
|
||||
1. Randomly sample N cameras (based on the "camid" label).
|
||||
2. From each camera, randomly sample K images.
|
||||
|
||||
Args:
|
||||
data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
|
||||
batch_size (int): batch size.
|
||||
n_domain (int): number of cameras to sample in a batch.
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, batch_size, n_domain):
|
||||
|
@ -94,18 +101,18 @@ class RandomDomainSampler(Sampler):
|
|||
|
||||
# Keep track of image indices for each domain
|
||||
self.domain_dict = defaultdict(list)
|
||||
for i, (_, _, camid) in enumerate(data_source):
|
||||
for i, items in enumerate(data_source):
|
||||
camid = items[2]
|
||||
self.domain_dict[camid].append(i)
|
||||
self.domains = list(self.domain_dict.keys())
|
||||
|
||||
# Make sure each domain has equal number of images
|
||||
# Make sure each domain can be assigned an equal number of images
|
||||
if n_domain is None or n_domain <= 0:
|
||||
n_domain = len(self.domains)
|
||||
assert batch_size % n_domain == 0
|
||||
self.n_img_per_domain = batch_size // n_domain
|
||||
|
||||
self.batch_size = batch_size
|
||||
# n_domain denotes number of domains sampled in a minibatch
|
||||
self.n_domain = n_domain
|
||||
self.length = len(list(self.__iter__()))
|
||||
|
||||
|
@ -135,12 +142,72 @@ class RandomDomainSampler(Sampler):
|
|||
return self.length
|
||||
|
||||
|
||||
class RandomDatasetSampler(Sampler):
|
||||
"""Random dataset sampler.
|
||||
|
||||
How does the sampling work:
|
||||
1. Randomly sample N datasets (based on the "dsetid" label).
|
||||
2. From each dataset, randomly sample K images.
|
||||
|
||||
Args:
|
||||
data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
|
||||
batch_size (int): batch size.
|
||||
n_dataset (int): number of datasets to sample in a batch.
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, batch_size, n_dataset):
|
||||
self.data_source = data_source
|
||||
|
||||
# Keep track of image indices for each dataset
|
||||
self.dataset_dict = defaultdict(list)
|
||||
for i, items in enumerate(data_source):
|
||||
dsetid = items[3]
|
||||
self.dataset_dict[dsetid].append(i)
|
||||
self.datasets = list(self.dataset_dict.keys())
|
||||
|
||||
# Make sure each dataset can be assigned an equal number of images
|
||||
if n_dataset is None or n_dataset <= 0:
|
||||
n_dataset = len(self.datasets)
|
||||
assert batch_size % n_dataset == 0
|
||||
self.n_img_per_dset = batch_size // n_dataset
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.n_dataset = n_dataset
|
||||
self.length = len(list(self.__iter__()))
|
||||
|
||||
def __iter__(self):
|
||||
dataset_dict = copy.deepcopy(self.dataset_dict)
|
||||
final_idxs = []
|
||||
stop_sampling = False
|
||||
|
||||
while not stop_sampling:
|
||||
selected_datasets = random.sample(self.datasets, self.n_dataset)
|
||||
|
||||
for dset in selected_datasets:
|
||||
idxs = dataset_dict[dset]
|
||||
selected_idxs = random.sample(idxs, self.n_img_per_dset)
|
||||
final_idxs.extend(selected_idxs)
|
||||
|
||||
for idx in selected_idxs:
|
||||
dataset_dict[dset].remove(idx)
|
||||
|
||||
remaining = len(dataset_dict[dset])
|
||||
if remaining < self.n_img_per_dset:
|
||||
stop_sampling = True
|
||||
|
||||
return iter(final_idxs)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def build_train_sampler(
|
||||
data_source,
|
||||
train_sampler,
|
||||
batch_size=32,
|
||||
num_instances=4,
|
||||
num_cams=1,
|
||||
num_datasets=1,
|
||||
**kwargs
|
||||
):
|
||||
"""Builds a training sampler.
|
||||
|
@ -153,6 +220,8 @@ def build_train_sampler(
|
|||
batch (when using ``RandomIdentitySampler``). Default is 4.
|
||||
num_cams (int, optional): number of cameras to sample in a batch (when using
|
||||
``RandomDomainSampler``). Default is 1.
|
||||
num_datasets (int, optional): number of datasets to sample in a batch (when
|
||||
using ``RandomDatasetSampler``). Default is 1.
|
||||
"""
|
||||
assert train_sampler in AVAI_SAMPLERS, \
|
||||
'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)
|
||||
|
@ -163,6 +232,9 @@ def build_train_sampler(
|
|||
elif train_sampler == 'RandomDomainSampler':
|
||||
sampler = RandomDomainSampler(data_source, batch_size, num_cams)
|
||||
|
||||
elif train_sampler == 'RandomDatasetSampler':
|
||||
sampler = RandomDatasetSampler(data_source, batch_size, num_datasets)
|
||||
|
||||
elif train_sampler == 'SequentialSampler':
|
||||
sampler = SequentialSampler(data_source)
|
||||
|
||||
|
|
Loading…
Reference in New Issue