parent
cfc5e9e17a
commit
cbb86fcce2
|
@ -6,6 +6,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from ignite.engine import Engine
|
from ignite.engine import Engine
|
||||||
|
|
||||||
from utils.reid_metric import R1_mAP, R1_mAP_reranking
|
from utils.reid_metric import R1_mAP, R1_mAP_reranking
|
||||||
|
@ -25,13 +26,15 @@ def create_supervised_evaluator(model, metrics,
|
||||||
Engine: an evaluator engine with supervised inference function
|
Engine: an evaluator engine with supervised inference function
|
||||||
"""
|
"""
|
||||||
if device:
|
if device:
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
model = nn.DataParallel(model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
def _inference(engine, batch):
|
def _inference(engine, batch):
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
data, pids, camids = batch
|
data, pids, camids = batch
|
||||||
data = data.cuda()
|
data = data.to(device) if torch.cuda.device_count() >= 1 else data
|
||||||
feat = model(data)
|
feat = model(data)
|
||||||
return feat, pids, camids
|
return feat, pids, camids
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from ignite.engine import Engine, Events
|
from ignite.engine import Engine, Events
|
||||||
from ignite.handlers import ModelCheckpoint, Timer
|
from ignite.handlers import ModelCheckpoint, Timer
|
||||||
from ignite.metrics import RunningAverage
|
from ignite.metrics import RunningAverage
|
||||||
|
@ -30,14 +31,16 @@ def create_supervised_trainer(model, optimizer, loss_fn,
|
||||||
Engine: a trainer engine with supervised update function
|
Engine: a trainer engine with supervised update function
|
||||||
"""
|
"""
|
||||||
if device:
|
if device:
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
model = nn.DataParallel(model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
def _update(engine, batch):
|
def _update(engine, batch):
|
||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
img, target = batch
|
img, target = batch
|
||||||
img = img.cuda()
|
img = img.to(device) if torch.cuda.device_count() >= 1 else img
|
||||||
target = target.cuda()
|
target = target.to(device) if torch.cuda.device_count() >= 1 else target
|
||||||
score, feat = model(img)
|
score, feat = model(img)
|
||||||
loss = loss_fn(score, feat, target)
|
loss = loss_fn(score, feat, target)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -65,6 +68,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
|
||||||
Engine: a trainer engine with supervised update function
|
Engine: a trainer engine with supervised update function
|
||||||
"""
|
"""
|
||||||
if device:
|
if device:
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
model = nn.DataParallel(model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
def _update(engine, batch):
|
def _update(engine, batch):
|
||||||
|
@ -72,8 +77,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
optimizer_center.zero_grad()
|
optimizer_center.zero_grad()
|
||||||
img, target = batch
|
img, target = batch
|
||||||
img = img.cuda()
|
img = img.to(device) if torch.cuda.device_count() >= 1 else img
|
||||||
target = target.cuda()
|
target = target.to(device) if torch.cuda.device_count() >= 1 else target
|
||||||
score, feat = model(img)
|
score, feat = model(img)
|
||||||
loss = loss_fn(score, feat, target)
|
loss = loss_fn(score, feat, target)
|
||||||
# print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
|
# print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
|
||||||
|
@ -104,13 +109,15 @@ def create_supervised_evaluator(model, metrics,
|
||||||
Engine: an evaluator engine with supervised inference function
|
Engine: an evaluator engine with supervised inference function
|
||||||
"""
|
"""
|
||||||
if device:
|
if device:
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
model = nn.DataParallel(model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
def _inference(engine, batch):
|
def _inference(engine, batch):
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
data, pids, camids = batch
|
data, pids, camids = batch
|
||||||
data = data.cuda()
|
data = data.to(device) if torch.cuda.device_count() >= 1 else data
|
||||||
feat = model(data)
|
feat = model(data)
|
||||||
return feat, pids, camids
|
return feat, pids, camids
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue