mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
market_test
This commit is contained in:
parent
be4383b3ca
commit
95ec05290a
@ -6,8 +6,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
datamanager = torchreid.data.ImageDataManager(
|
datamanager = torchreid.data.ImageDataManager(
|
||||||
root='reid-data',
|
root='reid-data',
|
||||||
sources='newdataset',
|
sources='market1501',
|
||||||
# targets='newdataset',
|
targets='market1501',
|
||||||
height=256,
|
height=256,
|
||||||
width=128,
|
width=128,
|
||||||
transforms=['random_flip', 'color_jitter']
|
transforms=['random_flip', 'color_jitter']
|
||||||
@ -15,7 +15,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
model = torchreid.models.build_model(
|
model = torchreid.models.build_model(
|
||||||
name='osnet_x1_0',
|
name='osnet_x1_0',
|
||||||
num_classes=0,
|
num_classes=datamanager._num_train_pids,
|
||||||
# loss='softmax',
|
# loss='softmax',
|
||||||
pretrained=True
|
pretrained=True
|
||||||
)
|
)
|
||||||
@ -39,12 +39,12 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
|
|
||||||
engine.run(
|
engine.run(
|
||||||
save_dir='log/newdataset_visrank',
|
save_dir='log/market_visrank',
|
||||||
test_only=True,
|
test_only=True,
|
||||||
# max_epoch=60,
|
# max_epoch=60,
|
||||||
# eval_freq=10,
|
# eval_freq=10,
|
||||||
# print_freq=10,
|
# print_freq=10,
|
||||||
visrank=True,
|
# visrank=True,
|
||||||
visrank_topk=20,
|
# visrank_topk=20,
|
||||||
# rerank=True
|
# rerank=True
|
||||||
)
|
)
|
||||||
|
@ -189,84 +189,84 @@ class ImageDataManager(DataManager):
|
|||||||
use_gpu=use_gpu
|
use_gpu=use_gpu
|
||||||
)
|
)
|
||||||
|
|
||||||
# print('=> Loading train (source) dataset')
|
print('=> Loading train (source) dataset')
|
||||||
# trainset = []
|
trainset = []
|
||||||
# for name in self.sources:
|
for name in self.sources:
|
||||||
# trainset_ = init_image_dataset(
|
trainset_ = init_image_dataset(
|
||||||
# name,
|
name,
|
||||||
# transform=self.transform_tr,
|
transform=self.transform_tr,
|
||||||
# k_tfm=k_tfm,
|
k_tfm=k_tfm,
|
||||||
# mode='train',
|
mode='train',
|
||||||
# combineall=combineall,
|
combineall=combineall,
|
||||||
# root=root,
|
root=root,
|
||||||
# split_id=split_id,
|
split_id=split_id,
|
||||||
# cuhk03_labeled=cuhk03_labeled,
|
cuhk03_labeled=cuhk03_labeled,
|
||||||
# cuhk03_classic_split=cuhk03_classic_split,
|
cuhk03_classic_split=cuhk03_classic_split,
|
||||||
# market1501_500k=market1501_500k
|
market1501_500k=market1501_500k
|
||||||
# )
|
)
|
||||||
# trainset.append(trainset_)
|
trainset.append(trainset_)
|
||||||
# trainset = sum(trainset)
|
trainset = sum(trainset)
|
||||||
|
|
||||||
# self._num_train_pids = trainset.num_train_pids
|
self._num_train_pids = trainset.num_train_pids
|
||||||
# self._num_train_cams = trainset.num_train_cams
|
self._num_train_cams = trainset.num_train_cams
|
||||||
|
|
||||||
# self.train_loader = torch.utils.data.DataLoader(
|
self.train_loader = torch.utils.data.DataLoader(
|
||||||
# trainset,
|
trainset,
|
||||||
# sampler=build_train_sampler(
|
sampler=build_train_sampler(
|
||||||
# trainset.train,
|
trainset.train,
|
||||||
# train_sampler,
|
train_sampler,
|
||||||
# batch_size=batch_size_train,
|
batch_size=batch_size_train,
|
||||||
# num_instances=num_instances,
|
num_instances=num_instances,
|
||||||
# num_cams=num_cams,
|
num_cams=num_cams,
|
||||||
# num_datasets=num_datasets
|
num_datasets=num_datasets
|
||||||
# ),
|
),
|
||||||
# batch_size=batch_size_train,
|
batch_size=batch_size_train,
|
||||||
# shuffle=False,
|
shuffle=False,
|
||||||
# num_workers=workers,
|
num_workers=workers,
|
||||||
# pin_memory=self.use_gpu,
|
pin_memory=self.use_gpu,
|
||||||
# drop_last=True
|
drop_last=True
|
||||||
# )
|
)
|
||||||
|
|
||||||
# self.train_loader_t = None
|
self.train_loader_t = None
|
||||||
# if load_train_targets:
|
if load_train_targets:
|
||||||
# # check if sources and targets are identical
|
# check if sources and targets are identical
|
||||||
# assert len(set(self.sources) & set(self.targets)) == 0, \
|
assert len(set(self.sources) & set(self.targets)) == 0, \
|
||||||
# 'sources={} and targets={} must not have overlap'.format(self.sources, self.targets)
|
'sources={} and targets={} must not have overlap'.format(self.sources, self.targets)
|
||||||
|
|
||||||
# print('=> Loading train (target) dataset')
|
print('=> Loading train (target) dataset')
|
||||||
# trainset_t = []
|
trainset_t = []
|
||||||
# for name in self.targets:
|
for name in self.targets:
|
||||||
# trainset_t_ = init_image_dataset(
|
trainset_t_ = init_image_dataset(
|
||||||
# name,
|
name,
|
||||||
# transform=self.transform_tr,
|
transform=self.transform_tr,
|
||||||
# k_tfm=k_tfm,
|
k_tfm=k_tfm,
|
||||||
# mode='train',
|
mode='train',
|
||||||
# combineall=False, # only use the training data
|
combineall=False, # only use the training data
|
||||||
# root=root,
|
root=root,
|
||||||
# split_id=split_id,
|
split_id=split_id,
|
||||||
# cuhk03_labeled=cuhk03_labeled,
|
cuhk03_labeled=cuhk03_labeled,
|
||||||
# cuhk03_classic_split=cuhk03_classic_split,
|
cuhk03_classic_split=cuhk03_classic_split,
|
||||||
# market1501_500k=market1501_500k
|
market1501_500k=market1501_500k
|
||||||
# )
|
)
|
||||||
# trainset_t.append(trainset_t_)
|
trainset_t.append(trainset_t_)
|
||||||
# trainset_t = sum(trainset_t)
|
trainset_t = sum(trainset_t)
|
||||||
|
|
||||||
# self.train_loader_t = torch.utils.data.DataLoader(
|
self.train_loader_t = torch.utils.data.DataLoader(
|
||||||
# trainset_t,
|
trainset_t,
|
||||||
# sampler=build_train_sampler(
|
sampler=build_train_sampler(
|
||||||
# trainset_t.train,
|
trainset_t.train,
|
||||||
# train_sampler_t,
|
train_sampler_t,
|
||||||
# batch_size=batch_size_train,
|
batch_size=batch_size_train,
|
||||||
# num_instances=num_instances,
|
num_instances=num_instances,
|
||||||
# num_cams=num_cams,
|
num_cams=num_cams,
|
||||||
# num_datasets=num_datasets
|
num_datasets=num_datasets
|
||||||
# ),
|
),
|
||||||
# batch_size=batch_size_train,
|
batch_size=batch_size_train,
|
||||||
# shuffle=False,
|
shuffle=False,
|
||||||
# num_workers=workers,
|
num_workers=workers,
|
||||||
# pin_memory=self.use_gpu,
|
pin_memory=self.use_gpu,
|
||||||
# drop_last=True
|
drop_last=True
|
||||||
# )
|
)
|
||||||
|
|
||||||
print('=> Loading test (target) dataset')
|
print('=> Loading test (target) dataset')
|
||||||
self.test_loader = {
|
self.test_loader = {
|
||||||
@ -335,9 +335,9 @@ class ImageDataManager(DataManager):
|
|||||||
print(' **************** Summary ****************')
|
print(' **************** Summary ****************')
|
||||||
print(' source : {}'.format(self.sources))
|
print(' source : {}'.format(self.sources))
|
||||||
print(' # source datasets : {}'.format(len(self.sources)))
|
print(' # source datasets : {}'.format(len(self.sources)))
|
||||||
# print(' # source ids : {}'.format(self.num_train_pids))
|
print(' # source ids : {}'.format(self.num_train_pids))
|
||||||
# print(' # source images : {}'.format(len(trainset)))
|
print(' # source images : {}'.format(len(trainset)))
|
||||||
# print(' # source cameras : {}'.format(self.num_train_cams))
|
print(' # source cameras : {}'.format(self.num_train_cams))
|
||||||
if load_train_targets:
|
if load_train_targets:
|
||||||
print(
|
print(
|
||||||
' # target images : {} (unlabeled)'.format(len(trainset_t))
|
' # target images : {} (unlabeled)'.format(len(trainset_t))
|
||||||
|
@ -54,8 +54,8 @@ class Dataset(object):
|
|||||||
# extend 3-tuple (img_path(s), pid, camid) to
|
# extend 3-tuple (img_path(s), pid, camid) to
|
||||||
# 4-tuple (img_path(s), pid, camid, dsetid) by
|
# 4-tuple (img_path(s), pid, camid, dsetid) by
|
||||||
# adding a dataset indicator "dsetid"
|
# adding a dataset indicator "dsetid"
|
||||||
# if len(train[0]) == 3:
|
if len(train[0]) == 3:
|
||||||
# train = [(*items, 0) for items in train]
|
train = [(*items, 0) for items in train]
|
||||||
if len(query[0]) == 3:
|
if len(query[0]) == 3:
|
||||||
query = [(*items, 0) for items in query]
|
query = [(*items, 0) for items in query]
|
||||||
if len(gallery[0]) == 3:
|
if len(gallery[0]) == 3:
|
||||||
|
@ -27,7 +27,7 @@ class Engine(object):
|
|||||||
|
|
||||||
def __init__(self, datamanager, use_gpu=True):
|
def __init__(self, datamanager, use_gpu=True):
|
||||||
self.datamanager = datamanager
|
self.datamanager = datamanager
|
||||||
# self.train_loader = self.datamanager.train_loader
|
self.train_loader = self.datamanager.train_loader
|
||||||
self.test_loader = self.datamanager.test_loader
|
self.test_loader = self.datamanager.test_loader
|
||||||
self.use_gpu = (torch.cuda.is_available() and use_gpu)
|
self.use_gpu = (torch.cuda.is_available() and use_gpu)
|
||||||
self.writer = None
|
self.writer = None
|
||||||
@ -394,10 +394,10 @@ class Engine(object):
|
|||||||
print(
|
print(
|
||||||
'Computing distance matrix with metric={} ...'.format(dist_metric)
|
'Computing distance matrix with metric={} ...'.format(dist_metric)
|
||||||
)
|
)
|
||||||
print(len(qf[0]))
|
# print(len(qf[0]))
|
||||||
distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
|
distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
|
||||||
distmat = distmat.numpy()
|
distmat = distmat.numpy()
|
||||||
print(distmat)
|
# print(distmat)
|
||||||
if rerank:
|
if rerank:
|
||||||
print('Applying person re-ranking ...')
|
print('Applying person re-ranking ...')
|
||||||
distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
|
distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
|
||||||
|
@ -70,7 +70,9 @@ class ImageSoftmaxEngine(Engine):
|
|||||||
self.register_model('model', model, optimizer, scheduler)
|
self.register_model('model', model, optimizer, scheduler)
|
||||||
|
|
||||||
self.criterion = CrossEntropyLoss(
|
self.criterion = CrossEntropyLoss(
|
||||||
num_classes=0, use_gpu=self.use_gpu, label_smooth=label_smooth
|
num_classes=self.datamanager.num_train_pids,
|
||||||
|
use_gpu=self.use_gpu,
|
||||||
|
label_smooth=label_smooth
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_backward(self, data):
|
def forward_backward(self, data):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user