[1.0.8] load_unlabeled_targets -> load_train_targets
parent
d6405de241
commit
2095ee05c5
|
@ -38,7 +38,7 @@ You can find some research projects that are built on top of Torchreid `here <ht
|
|||
|
||||
What's new
|
||||
---------------
|
||||
- [Nov 19] ``ImageDataManager`` can load training data from target datasets by setting ``load_unlabeled_targets=True``, and the train-loader can be accessed with ``train_loader_u = datamanager.train_loader_u``. This feature is useful for domain adaptation research.
|
||||
- [Nov 19] ``ImageDataManager`` can load training data from target datasets by setting ``load_train_targets=True``, and the train-loader can be accessed with ``train_loader_t = datamanager.train_loader_t``. This feature is useful for domain adaptation research.
|
||||
|
||||
|
||||
Installation
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
|
||||
__version__ = '1.0.7'
|
||||
__version__ = '1.0.8'
|
||||
__author__ = 'Kaiyang Zhou'
|
||||
__homepage__ = 'https://kaiyangzhou.github.io/'
|
||||
__description__ = 'Deep learning person re-identification in PyTorch'
|
||||
|
|
|
@ -93,8 +93,8 @@ class ImageDataManager(DataManager):
|
|||
split_id (int, optional): split id (*0-based*). Default is 0.
|
||||
combineall (bool, optional): combine train, query and gallery in a dataset for
|
||||
training. Default is False.
|
||||
load_unlabeled_targets (bool, optional): construct train loader for unlabeled target
|
||||
datasets. Default is False.
|
||||
load_train_targets (bool, optional): construct train-loader for target datasets.
|
||||
Default is False. This is useful for domain adaptation research.
|
||||
batch_size_train (int, optional): number of images in a training batch. Default is 32.
|
||||
batch_size_test (int, optional): number of images in a test batch. Default is 32.
|
||||
workers (int, optional): number of workers. Default is 4.
|
||||
|
@ -126,7 +126,7 @@ class ImageDataManager(DataManager):
|
|||
test_loader = datamanager.test_loader
|
||||
|
||||
# return train loader of target data
|
||||
train_loader_u = datamanager.train_loader_u
|
||||
train_loader_t = datamanager.train_loader_t
|
||||
"""
|
||||
data_type = 'image'
|
||||
|
||||
|
@ -143,7 +143,7 @@ class ImageDataManager(DataManager):
|
|||
use_gpu=True,
|
||||
split_id=0,
|
||||
combineall=False,
|
||||
load_unlabeled_targets=False,
|
||||
load_train_targets=False,
|
||||
batch_size_train=32,
|
||||
batch_size_test=32,
|
||||
workers=4,
|
||||
|
@ -194,33 +194,33 @@ class ImageDataManager(DataManager):
|
|||
drop_last=True
|
||||
)
|
||||
|
||||
self.train_loader_u = None
|
||||
if load_unlabeled_targets:
|
||||
self.train_loader_t = None
|
||||
if load_train_targets:
|
||||
# check if sources and targets are identical
|
||||
assert len(set(self.sources) & set(self.targets)) == 0, \
|
||||
'sources={} and targets={} must not have overlap'.format(self.sources, self.targets)
|
||||
|
||||
print('=> Loading train (target) dataset')
|
||||
trainset_u = []
|
||||
trainset_t = []
|
||||
for name in self.targets:
|
||||
trainset_u_ = init_image_dataset(
|
||||
trainset_t_ = init_image_dataset(
|
||||
name,
|
||||
transform=self.transform_tr,
|
||||
mode='train',
|
||||
combineall=combineall,
|
||||
combineall=False, # only use the training data
|
||||
root=root,
|
||||
split_id=split_id,
|
||||
cuhk03_labeled=cuhk03_labeled,
|
||||
cuhk03_classic_split=cuhk03_classic_split,
|
||||
market1501_500k=market1501_500k
|
||||
)
|
||||
trainset_u.append(trainset_u_)
|
||||
trainset_u = sum(trainset_u)
|
||||
trainset_t.append(trainset_t_)
|
||||
trainset_t = sum(trainset_t)
|
||||
|
||||
self.train_loader_u = torch.utils.data.DataLoader(
|
||||
trainset_u,
|
||||
self.train_loader_t = torch.utils.data.DataLoader(
|
||||
trainset_t,
|
||||
sampler=build_train_sampler(
|
||||
trainset_u.train, train_sampler,
|
||||
trainset_t.train, train_sampler,
|
||||
batch_size=batch_size_train,
|
||||
num_instances=num_instances
|
||||
),
|
||||
|
@ -289,9 +289,9 @@ class ImageDataManager(DataManager):
|
|||
print(' # source ids : {}'.format(self.num_train_pids))
|
||||
print(' # source images : {}'.format(len(trainset)))
|
||||
print(' # source cameras : {}'.format(self.num_train_cams))
|
||||
if load_train_targets:
|
||||
print(' # target images : {} (unlabeled)'.format(len(trainset_t)))
|
||||
print(' target : {}'.format(self.targets))
|
||||
if load_unlabeled_targets:
|
||||
print(' # target images : {} (unlabeled)'.format(len(trainset_u)))
|
||||
print(' *****************************************')
|
||||
print('\n')
|
||||
|
||||
|
|
Loading…
Reference in New Issue