parent
cfc5e9e17a
commit
cbb86fcce2
|
@ -6,6 +6,7 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.engine import Engine
|
||||
|
||||
from utils.reid_metric import R1_mAP, R1_mAP_reranking
|
||||
|
@ -25,13 +26,15 @@ def create_supervised_evaluator(model, metrics,
|
|||
Engine: an evaluator engine with supervised inference function
|
||||
"""
|
||||
if device:
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
model.to(device)
|
||||
|
||||
def _inference(engine, batch):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
data, pids, camids = batch
|
||||
data = data.cuda()
|
||||
data = data.to(device) if torch.cuda.device_count() >= 1 else data
|
||||
feat = model(data)
|
||||
return feat, pids, camids
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.engine import Engine, Events
|
||||
from ignite.handlers import ModelCheckpoint, Timer
|
||||
from ignite.metrics import RunningAverage
|
||||
|
@ -30,14 +31,16 @@ def create_supervised_trainer(model, optimizer, loss_fn,
|
|||
Engine: a trainer engine with supervised update function
|
||||
"""
|
||||
if device:
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
model.to(device)
|
||||
|
||||
def _update(engine, batch):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
img, target = batch
|
||||
img = img.cuda()
|
||||
target = target.cuda()
|
||||
img = img.to(device) if torch.cuda.device_count() >= 1 else img
|
||||
target = target.to(device) if torch.cuda.device_count() >= 1 else target
|
||||
score, feat = model(img)
|
||||
loss = loss_fn(score, feat, target)
|
||||
loss.backward()
|
||||
|
@ -65,6 +68,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
|
|||
Engine: a trainer engine with supervised update function
|
||||
"""
|
||||
if device:
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
model.to(device)
|
||||
|
||||
def _update(engine, batch):
|
||||
|
@ -72,8 +77,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
|
|||
optimizer.zero_grad()
|
||||
optimizer_center.zero_grad()
|
||||
img, target = batch
|
||||
img = img.cuda()
|
||||
target = target.cuda()
|
||||
img = img.to(device) if torch.cuda.device_count() >= 1 else img
|
||||
target = target.to(device) if torch.cuda.device_count() >= 1 else target
|
||||
score, feat = model(img)
|
||||
loss = loss_fn(score, feat, target)
|
||||
# print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
|
||||
|
@ -104,13 +109,15 @@ def create_supervised_evaluator(model, metrics,
|
|||
Engine: an evaluator engine with supervised inference function
|
||||
"""
|
||||
if device:
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
model.to(device)
|
||||
|
||||
def _inference(engine, batch):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
data, pids, camids = batch
|
||||
data = data.cuda()
|
||||
data = data.to(device) if torch.cuda.device_count() >= 1 else data
|
||||
feat = model(data)
|
||||
return feat, pids, camids
|
||||
|
||||
|
|
Loading…
Reference in New Issue