Multi-gpu compatibility; Load data to GPU

only when GPUs are visible
pull/19/head
garyfan 2019-04-16 18:35:18 +08:00
parent cfc5e9e17a
commit cbb86fcce2
2 changed files with 16 additions and 6 deletions

View File

@ -6,6 +6,7 @@
import logging
import torch
import torch.nn as nn
from ignite.engine import Engine
from utils.reid_metric import R1_mAP, R1_mAP_reranking
@ -25,13 +26,15 @@ def create_supervised_evaluator(model, metrics,
Engine: an evaluator engine with supervised inference function
"""
if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device)
def _inference(engine, batch):
model.eval()
with torch.no_grad():
data, pids, camids = batch
data = data.cuda()
data = data.to(device) if torch.cuda.device_count() >= 1 else data
feat = model(data)
return feat, pids, camids

View File

@ -7,6 +7,7 @@
import logging
import torch
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
@ -30,14 +31,16 @@ def create_supervised_trainer(model, optimizer, loss_fn,
Engine: a trainer engine with supervised update function
"""
if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device)
def _update(engine, batch):
model.train()
optimizer.zero_grad()
img, target = batch
img = img.cuda()
target = target.cuda()
img = img.to(device) if torch.cuda.device_count() >= 1 else img
target = target.to(device) if torch.cuda.device_count() >= 1 else target
score, feat = model(img)
loss = loss_fn(score, feat, target)
loss.backward()
@ -65,6 +68,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
Engine: a trainer engine with supervised update function
"""
if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device)
def _update(engine, batch):
@ -72,8 +77,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
optimizer.zero_grad()
optimizer_center.zero_grad()
img, target = batch
img = img.cuda()
target = target.cuda()
img = img.to(device) if torch.cuda.device_count() >= 1 else img
target = target.to(device) if torch.cuda.device_count() >= 1 else target
score, feat = model(img)
loss = loss_fn(score, feat, target)
# print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
@ -104,13 +109,15 @@ def create_supervised_evaluator(model, metrics,
Engine: an evaluator engine with supervised inference function
"""
if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device)
def _inference(engine, batch):
model.eval()
with torch.no_grad():
data, pids, camids = batch
data = data.cuda()
data = data.to(device) if torch.cuda.device_count() >= 1 else data
feat = model(data)
return feat, pids, camids