v1.2.6: add RandomDomainSampler
parent
b075bd3f4b
commit
ad4d8d7c29
|
@ -42,6 +42,7 @@ def get_default_config():
|
|||
cfg.sampler.train_sampler = 'RandomSampler' # sampler for source train loader
|
||||
cfg.sampler.train_sampler_t = 'RandomSampler' # sampler for target train loader
|
||||
cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler
|
||||
cfg.sampler.num_cams = 1 # number of cameras to sample in a batch (for RandomDomainSampler)
|
||||
|
||||
# video reid setting
|
||||
cfg.video = CN()
|
||||
|
@ -126,6 +127,7 @@ def imagedata_kwargs(cfg):
|
|||
'batch_size_test': cfg.test.batch_size,
|
||||
'workers': cfg.data.workers,
|
||||
'num_instances': cfg.sampler.num_instances,
|
||||
'num_cams': cfg.sampler.num_cams,
|
||||
'train_sampler': cfg.sampler.train_sampler,
|
||||
'train_sampler_t': cfg.sampler.train_sampler_t,
|
||||
# image
|
||||
|
@ -152,6 +154,7 @@ def videodata_kwargs(cfg):
|
|||
'batch_size_test': cfg.test.batch_size,
|
||||
'workers': cfg.data.workers,
|
||||
'num_instances': cfg.sampler.num_instances,
|
||||
'num_cams': cfg.sampler.num_cams,
|
||||
'train_sampler': cfg.sampler.train_sampler,
|
||||
# video
|
||||
'seq_len': cfg.video.seq_len,
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import print_function, absolute_import
|
|||
|
||||
from torchreid import data, optim, utils, engine, losses, models, metrics
|
||||
|
||||
__version__ = '1.2.5'
|
||||
__version__ = '1.2.6'
|
||||
__author__ = 'Kaiyang Zhou'
|
||||
__homepage__ = 'https://kaiyangzhou.github.io/'
|
||||
__description__ = 'Deep learning person re-identification in PyTorch'
|
||||
|
|
|
@ -115,6 +115,8 @@ class ImageDataManager(DataManager):
|
|||
workers (int, optional): number of workers. Default is 4.
|
||||
num_instances (int, optional): number of instances per identity in a batch.
|
||||
Default is 4.
|
||||
num_cams (int, optional): number of cameras to sample in a batch (when using
|
||||
``RandomDomainSampler``). Default is 1.
|
||||
train_sampler (str, optional): sampler. Default is RandomSampler.
|
||||
train_sampler_t (str, optional): sampler for target train loader. Default is RandomSampler.
|
||||
cuhk03_labeled (bool, optional): use cuhk03 labeled images.
|
||||
|
@ -165,6 +167,7 @@ class ImageDataManager(DataManager):
|
|||
batch_size_test=32,
|
||||
workers=4,
|
||||
num_instances=4,
|
||||
num_cams=1,
|
||||
train_sampler='RandomSampler',
|
||||
train_sampler_t='RandomSampler',
|
||||
cuhk03_labeled=False,
|
||||
|
@ -210,7 +213,8 @@ class ImageDataManager(DataManager):
|
|||
trainset.train,
|
||||
train_sampler,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances
|
||||
num_instances=num_instances,
|
||||
num_cams=num_cams
|
||||
),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=False,
|
||||
|
@ -249,7 +253,8 @@ class ImageDataManager(DataManager):
|
|||
trainset_t.train,
|
||||
train_sampler_t,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances
|
||||
num_instances=num_instances,
|
||||
num_cams=num_cams
|
||||
),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=False,
|
||||
|
@ -360,6 +365,8 @@ class VideoDataManager(DataManager):
|
|||
workers (int, optional): number of workers. Default is 4.
|
||||
num_instances (int, optional): number of instances per identity in a batch.
|
||||
Default is 4.
|
||||
num_cams (int, optional): number of cameras to sample in a batch (when using
|
||||
``RandomDomainSampler``). Default is 1.
|
||||
train_sampler (str, optional): sampler. Default is RandomSampler.
|
||||
seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
|
||||
sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
|
||||
|
@ -411,6 +418,7 @@ class VideoDataManager(DataManager):
|
|||
batch_size_test=3,
|
||||
workers=4,
|
||||
num_instances=4,
|
||||
num_cams=1,
|
||||
train_sampler='RandomSampler',
|
||||
seq_len=15,
|
||||
sample_method='evenly'
|
||||
|
@ -450,7 +458,8 @@ class VideoDataManager(DataManager):
|
|||
trainset.train,
|
||||
train_sampler,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances
|
||||
num_instances=num_instances,
|
||||
num_cams=num_cams
|
||||
)
|
||||
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
|
|
|
@ -5,7 +5,10 @@ import random
|
|||
from collections import defaultdict
|
||||
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
|
||||
|
||||
AVAI_SAMPLERS = ['RandomIdentitySampler', 'SequentialSampler', 'RandomSampler']
|
||||
AVAI_SAMPLERS = [
|
||||
'RandomIdentitySampler', 'SequentialSampler', 'RandomSampler',
|
||||
'RandomDomainSampler'
|
||||
]
|
||||
|
||||
|
||||
class RandomIdentitySampler(Sampler):
|
||||
|
@ -77,8 +80,68 @@ class RandomIdentitySampler(Sampler):
|
|||
return self.length
|
||||
|
||||
|
||||
class RandomDomainSampler(Sampler):
|
||||
"""Random domain sampler.
|
||||
|
||||
Each camera is considered as a distinct domain.
|
||||
|
||||
1. Randomly sample N cameras.
|
||||
2. From each camera, randomly sample K images.
|
||||
"""
|
||||
|
||||
def __init__(self, data_source, batch_size, n_domain):
|
||||
self.data_source = data_source
|
||||
|
||||
# Keep track of image indices for each domain
|
||||
self.domain_dict = defaultdict(list)
|
||||
for i, (_, _, camid) in enumerate(data_source):
|
||||
self.domain_dict[camid].append(i)
|
||||
self.domains = list(self.domain_dict.keys())
|
||||
|
||||
# Make sure each domain has equal number of images
|
||||
if n_domain is None or n_domain <= 0:
|
||||
n_domain = len(self.domains)
|
||||
assert batch_size % n_domain == 0
|
||||
self.n_img_per_domain = batch_size // n_domain
|
||||
|
||||
self.batch_size = batch_size
|
||||
# n_domain denotes number of domains sampled in a minibatch
|
||||
self.n_domain = n_domain
|
||||
self.length = len(list(self.__iter__()))
|
||||
|
||||
def __iter__(self):
|
||||
domain_dict = copy.deepcopy(self.domain_dict)
|
||||
final_idxs = []
|
||||
stop_sampling = False
|
||||
|
||||
while not stop_sampling:
|
||||
selected_domains = random.sample(self.domains, self.n_domain)
|
||||
|
||||
for domain in selected_domains:
|
||||
idxs = domain_dict[domain]
|
||||
selected_idxs = random.sample(idxs, self.n_img_per_domain)
|
||||
final_idxs.extend(selected_idxs)
|
||||
|
||||
for idx in selected_idxs:
|
||||
domain_dict[domain].remove(idx)
|
||||
|
||||
remaining = len(domain_dict[domain])
|
||||
if remaining < self.n_img_per_domain:
|
||||
stop_sampling = True
|
||||
|
||||
return iter(final_idxs)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def build_train_sampler(
|
||||
data_source, train_sampler, batch_size=32, num_instances=4, **kwargs
|
||||
data_source,
|
||||
train_sampler,
|
||||
batch_size=32,
|
||||
num_instances=4,
|
||||
num_cams=1,
|
||||
**kwargs
|
||||
):
|
||||
"""Builds a training sampler.
|
||||
|
||||
|
@ -88,6 +151,8 @@ def build_train_sampler(
|
|||
batch_size (int, optional): batch size. Default is 32.
|
||||
num_instances (int, optional): number of instances per identity in a
|
||||
batch (when using ``RandomIdentitySampler``). Default is 4.
|
||||
num_cams (int, optional): number of cameras to sample in a batch (when using
|
||||
``RandomDomainSampler``). Default is 1.
|
||||
"""
|
||||
assert train_sampler in AVAI_SAMPLERS, \
|
||||
'train_sampler must be one of {}, but got {}'.format(AVAI_SAMPLERS, train_sampler)
|
||||
|
@ -95,6 +160,9 @@ def build_train_sampler(
|
|||
if train_sampler == 'RandomIdentitySampler':
|
||||
sampler = RandomIdentitySampler(data_source, batch_size, num_instances)
|
||||
|
||||
elif train_sampler == 'RandomDomainSampler':
|
||||
sampler = RandomDomainSampler(data_source, batch_size, num_cams)
|
||||
|
||||
elif train_sampler == 'SequentialSampler':
|
||||
sampler = SequentialSampler(data_source)
|
||||
|
||||
|
|
Loading…
Reference in New Issue