[1.0.8] load_unlabeled_targets -> load_train_targets

pull/294/head
KaiyangZhou 2019-11-27 21:20:04 +00:00
parent d6405de241
commit 2095ee05c5
3 changed files with 18 additions and 18 deletions

View File

@ -38,7 +38,7 @@ You can find some research projects that are built on top of Torchreid `here <ht
What's new
---------------
- [Nov 19] ``ImageDataManager`` can load training data from target datasets by setting ``load_unlabeled_targets=True``, and the train-loader can be accessed with ``train_loader_u = datamanager.train_loader_u``. This feature is useful for domain adaptation research.
- [Nov 19] ``ImageDataManager`` can load training data from target datasets by setting ``load_train_targets=True``, and the train-loader can be accessed with ``train_loader_t = datamanager.train_loader_t``. This feature is useful for domain adaptation research.
Installation

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import
from __future__ import print_function
__version__ = '1.0.7'
__version__ = '1.0.8'
__author__ = 'Kaiyang Zhou'
__homepage__ = 'https://kaiyangzhou.github.io/'
__description__ = 'Deep learning person re-identification in PyTorch'

View File

@ -93,8 +93,8 @@ class ImageDataManager(DataManager):
split_id (int, optional): split id (*0-based*). Default is 0.
combineall (bool, optional): combine train, query and gallery in a dataset for
training. Default is False.
load_unlabeled_targets (bool, optional): construct train loader for unlabeled target
datasets. Default is False.
load_train_targets (bool, optional): construct train-loader for target datasets.
Default is False. This is useful for domain adaptation research.
batch_size_train (int, optional): number of images in a training batch. Default is 32.
batch_size_test (int, optional): number of images in a test batch. Default is 32.
workers (int, optional): number of workers. Default is 4.
@ -126,7 +126,7 @@ class ImageDataManager(DataManager):
test_loader = datamanager.test_loader
# return train loader of target data
train_loader_u = datamanager.train_loader_u
train_loader_t = datamanager.train_loader_t
"""
data_type = 'image'
@ -143,7 +143,7 @@ class ImageDataManager(DataManager):
use_gpu=True,
split_id=0,
combineall=False,
load_unlabeled_targets=False,
load_train_targets=False,
batch_size_train=32,
batch_size_test=32,
workers=4,
@ -194,33 +194,33 @@ class ImageDataManager(DataManager):
drop_last=True
)
self.train_loader_u = None
if load_unlabeled_targets:
self.train_loader_t = None
if load_train_targets:
# check if sources and targets are identical
assert len(set(self.sources) & set(self.targets)) == 0, \
'sources={} and targets={} must not have overlap'.format(self.sources, self.targets)
print('=> Loading train (target) dataset')
trainset_u = []
trainset_t = []
for name in self.targets:
trainset_u_ = init_image_dataset(
trainset_t_ = init_image_dataset(
name,
transform=self.transform_tr,
mode='train',
combineall=combineall,
combineall=False, # only use the training data
root=root,
split_id=split_id,
cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split,
market1501_500k=market1501_500k
)
trainset_u.append(trainset_u_)
trainset_u = sum(trainset_u)
trainset_t.append(trainset_t_)
trainset_t = sum(trainset_t)
self.train_loader_u = torch.utils.data.DataLoader(
trainset_u,
self.train_loader_t = torch.utils.data.DataLoader(
trainset_t,
sampler=build_train_sampler(
trainset_u.train, train_sampler,
trainset_t.train, train_sampler,
batch_size=batch_size_train,
num_instances=num_instances
),
@ -289,9 +289,9 @@ class ImageDataManager(DataManager):
print(' # source ids : {}'.format(self.num_train_pids))
print(' # source images : {}'.format(len(trainset)))
print(' # source cameras : {}'.format(self.num_train_cams))
if load_train_targets:
print(' # target images : {} (unlabeled)'.format(len(trainset_t)))
print(' target : {}'.format(self.targets))
if load_unlabeled_targets:
print(' # target images : {} (unlabeled)'.format(len(trainset_u)))
print(' *****************************************')
print('\n')