mv data_time after .cuda()

pull/62/head
KaiyangZhou 2018-05-23 11:42:32 +01:00
parent 49eed19f55
commit 8b53814783
6 changed files with 23 additions and 17 deletions

View File

@ -208,11 +208,12 @@ def train(epoch, model, criterion_xent, criterion_cent, optimizer_model, optimiz
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
# measure data loading time
data_time.update(time.time() - end)
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
outputs, features = model(imgs)
xentloss = criterion_xent(outputs, pids)
centloss = criterion_cent(features, pids) * args.weight_cent

View File

@ -208,11 +208,12 @@ def train(epoch, model, criterion_xent, criterion_ring, optimizer, trainloader,
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
# measure data loading time
data_time.update(time.time() - end)
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
# measure data loading time
data_time.update(time.time() - end)
outputs, features = model(imgs)
xentloss = criterion_xent(outputs, pids)
ringloss = criterion_ring(features)

View File

@ -203,11 +203,12 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu):
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
# measure data loading time
data_time.update(time.time() - end)
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
# measure data loading time
data_time.update(time.time() - end)
outputs = model(imgs)
if isinstance(outputs, tuple):
loss = DeepSupervision(criterion, outputs, pids)

View File

@ -212,11 +212,12 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
# measure data loading time
data_time.update(time.time() - end)
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
# measure data loading time
data_time.update(time.time() - end)
outputs, features = model(imgs)
if args.htri_only:
if isinstance(features, tuple):

View File

@ -200,11 +200,12 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu):
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
# measure data loading time
data_time.update(time.time() - end)
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
# measure data loading time
data_time.update(time.time() - end)
outputs = model(imgs)
loss = criterion(outputs, pids)
optimizer.zero_grad()

View File

@ -209,11 +209,12 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
end = time.time()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
# measure data loading time
data_time.update(time.time() - end)
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
# measure data loading time
data_time.update(time.time() - end)
outputs, features = model(imgs)
if args.htri_only:
# only use hard triplet loss to train the network