Multi-gpu compatibility; Load data to GPU

only when GPUs are visible
pull/19/head
garyfan 2019-04-16 18:35:18 +08:00
parent cfc5e9e17a
commit cbb86fcce2
2 changed files with 16 additions and 6 deletions

View File

@ -6,6 +6,7 @@
import logging import logging
import torch import torch
import torch.nn as nn
from ignite.engine import Engine from ignite.engine import Engine
from utils.reid_metric import R1_mAP, R1_mAP_reranking from utils.reid_metric import R1_mAP, R1_mAP_reranking
@ -25,13 +26,15 @@ def create_supervised_evaluator(model, metrics,
Engine: an evaluator engine with supervised inference function Engine: an evaluator engine with supervised inference function
""" """
if device: if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device) model.to(device)
def _inference(engine, batch): def _inference(engine, batch):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
data, pids, camids = batch data, pids, camids = batch
data = data.cuda() data = data.to(device) if torch.cuda.device_count() >= 1 else data
feat = model(data) feat = model(data)
return feat, pids, camids return feat, pids, camids

View File

@ -7,6 +7,7 @@
import logging import logging
import torch import torch
import torch.nn as nn
from ignite.engine import Engine, Events from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage from ignite.metrics import RunningAverage
@ -30,14 +31,16 @@ def create_supervised_trainer(model, optimizer, loss_fn,
Engine: a trainer engine with supervised update function Engine: a trainer engine with supervised update function
""" """
if device: if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device) model.to(device)
def _update(engine, batch): def _update(engine, batch):
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
img, target = batch img, target = batch
img = img.cuda() img = img.to(device) if torch.cuda.device_count() >= 1 else img
target = target.cuda() target = target.to(device) if torch.cuda.device_count() >= 1 else target
score, feat = model(img) score, feat = model(img)
loss = loss_fn(score, feat, target) loss = loss_fn(score, feat, target)
loss.backward() loss.backward()
@ -65,6 +68,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
Engine: a trainer engine with supervised update function Engine: a trainer engine with supervised update function
""" """
if device: if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device) model.to(device)
def _update(engine, batch): def _update(engine, batch):
@ -72,8 +77,8 @@ def create_supervised_trainer_with_center(model, center_criterion, optimizer, op
optimizer.zero_grad() optimizer.zero_grad()
optimizer_center.zero_grad() optimizer_center.zero_grad()
img, target = batch img, target = batch
img = img.cuda() img = img.to(device) if torch.cuda.device_count() >= 1 else img
target = target.cuda() target = target.to(device) if torch.cuda.device_count() >= 1 else target
score, feat = model(img) score, feat = model(img)
loss = loss_fn(score, feat, target) loss = loss_fn(score, feat, target)
# print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target))) # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
@ -104,13 +109,15 @@ def create_supervised_evaluator(model, metrics,
Engine: an evaluator engine with supervised inference function Engine: an evaluator engine with supervised inference function
""" """
if device: if device:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to(device) model.to(device)
def _inference(engine, batch): def _inference(engine, batch):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
data, pids, camids = batch data, pids, camids = batch
data = data.cuda() data = data.to(device) if torch.cuda.device_count() >= 1 else data
feat = model(data) feat = model(data)
return feat, pids, camids return feat, pids, camids