mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
return (imgs, pids, camids, img_paths) for image dataloaders
This commit is contained in:
parent
db157c1345
commit
8c53007138
@ -44,7 +44,7 @@ class ImageDataset(Dataset):
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
return img, pid, camid
|
||||
return img, pid, camid, img_path
|
||||
|
||||
|
||||
class VideoDataset(Dataset):
|
||||
|
@ -163,7 +163,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=Fals
|
||||
open_all_layers(model)
|
||||
|
||||
end = time.time()
|
||||
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
|
||||
for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
if use_gpu:
|
||||
@ -200,7 +200,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
|
||||
|
||||
with torch.no_grad():
|
||||
qf, q_pids, q_camids = [], [], []
|
||||
for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
|
||||
for batch_idx, (imgs, pids, camids, _) in enumerate(queryloader):
|
||||
if use_gpu: imgs = imgs.cuda()
|
||||
|
||||
end = time.time()
|
||||
@ -219,7 +219,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
|
||||
|
||||
gf, g_pids, g_camids = [], [], []
|
||||
end = time.time()
|
||||
for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
|
||||
for batch_idx, (imgs, pids, camids, _) in enumerate(galleryloader):
|
||||
if use_gpu: imgs = imgs.cuda()
|
||||
|
||||
end = time.time()
|
||||
|
@ -166,7 +166,7 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
||||
open_all_layers(model)
|
||||
|
||||
end = time.time()
|
||||
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
|
||||
for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
if use_gpu:
|
||||
@ -216,7 +216,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
|
||||
|
||||
with torch.no_grad():
|
||||
qf, q_pids, q_camids = [], [], []
|
||||
for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
|
||||
for batch_idx, (imgs, pids, camids, _) in enumerate(queryloader):
|
||||
if use_gpu:
|
||||
imgs = imgs.cuda()
|
||||
|
||||
@ -235,7 +235,7 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], retur
|
||||
print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
|
||||
|
||||
gf, g_pids, g_camids = [], [], []
|
||||
for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
|
||||
for batch_idx, (imgs, pids, camids, _) in enumerate(galleryloader):
|
||||
if use_gpu:
|
||||
imgs = imgs.cuda()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user