return (imgs, pids, camids, img_paths) for image dataloaders

This commit is contained in:
KaiyangZhou 2018-11-10 11:54:06 +00:00
parent db157c1345
commit 8c53007138
3 changed files with 7 additions and 7 deletions

View File

@ -44,7 +44,7 @@ class ImageDataset(Dataset):
if self.transform is not None:
img = self.transform(img)
return img, pid, camid
return img, pid, camid, img_path
class VideoDataset(Dataset):

View File

@ -163,7 +163,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=Fals
open_all_layers(model)
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
data_time.update(time.time() - end)
if use_gpu:
@ -200,7 +200,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
with torch.no_grad():
qf, q_pids, q_camids = [], [], []
for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
for batch_idx, (imgs, pids, camids, _) in enumerate(queryloader):
if use_gpu: imgs = imgs.cuda()
end = time.time()
@ -219,7 +219,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
gf, g_pids, g_camids = [], [], []
end = time.time()
for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
for batch_idx, (imgs, pids, camids, _) in enumerate(galleryloader):
if use_gpu: imgs = imgs.cuda()
end = time.time()

View File

@ -166,7 +166,7 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
open_all_layers(model)
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
data_time.update(time.time() - end)
if use_gpu:
@ -216,7 +216,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
with torch.no_grad():
qf, q_pids, q_camids = [], [], []
for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
for batch_idx, (imgs, pids, camids, _) in enumerate(queryloader):
if use_gpu:
imgs = imgs.cuda()
@ -235,7 +235,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
gf, g_pids, g_camids = [], [], []
for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
for batch_idx, (imgs, pids, camids, _) in enumerate(galleryloader):
if use_gpu:
imgs = imgs.cuda()