# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage

from utils.reid_metric import R1_mAP


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        # move the batch to the same device as the model
        if device:
            img = img.to(device)
            target = target.to(device)
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        loss.backward()
        optimizer.step()
        # compute training accuracy from the classification logits
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
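
# Note: the trainer expects ``loss_fn(score, feat, target)``, i.e. a callable that
# combines a classification term on the ID logits with a metric-learning term on
# the features. A minimal sketch of a compatible loss (illustrative only -- the
# actual loss is built elsewhere in this repo):
#
#   import torch.nn.functional as F
#
#   def example_loss_fn(score, feat, target):
#       # classification term on the ID logits; a metric term on `feat`
#       # (e.g. a triplet loss) would typically be added here as well
#       return F.cross_entropy(score, target)
#
# Anything with this (score, feat, target) -> scalar signature can be passed in.
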
def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, center_loss_weight,
                                          device=None):
    """
    Factory function for creating a trainer for supervised models with center loss.

    Args:
        model (`torch.nn.Module`): the model to train
        center_criterion (`torch.nn.Module`): the center loss module holding the learnable class centers
        optimizer (`torch.optim.Optimizer`): the optimizer for the model parameters
        optimizer_center (`torch.optim.Optimizer`): the optimizer for the center loss parameters
        loss_fn (torch.nn loss function): the loss function to use
        center_loss_weight (float): weight of the center loss term inside `loss_fn`
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        optimizer_center.zero_grad()
        img, target = batch
        # move the batch to the same device as the model
        if device:
            img = img.to(device)
            target = target.to(device)
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target)))
        loss.backward()
        optimizer.step()
        # the center loss term inside `loss` was scaled by `center_loss_weight`,
        # so undo that scaling on the centers' gradients before stepping them
        for param in center_criterion.parameters():
            param.grad.data *= (1. / center_loss_weight)
        optimizer_center.step()

        # compute training accuracy from the classification logits
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
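
# Illustrative sketch (constructor and argument names below are assumptions, not
# definitions made in this file): the center criterion is expected to be an
# ``nn.Module`` whose parameters are the learnable class centers, optimized
# separately from the model, e.g.:
#
#   # center_criterion = CenterLoss(num_classes=num_classes, feat_dim=2048)  # hypothetical ctor
#   # optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=0.5)
#
# Dividing the centers' gradients by ``center_loss_weight`` keeps the effective
# learning rate of the centers independent of how strongly the center term is
# weighted in the total loss.
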
def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models.

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            # move the batch to the same device as the model
            if device:
                data = data.to(device)
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
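
# Minimal usage sketch (assuming a `model`, a `val_loader` yielding
# (images, pids, camids) batches with the `num_query` query images first, and a
# `cfg` node providing TEST.FEAT_NORM, as set up by the rest of this repo):
#
#   # evaluator = create_supervised_evaluator(
#   #     model,
#   #     metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
#   #     device='cuda')
#   # evaluator.run(val_loader)
#   # cmc, mAP = evaluator.state.metrics['r1_mAP']
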
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    # save a checkpoint every `checkpoint_period` epochs, keeping the last 10
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict()})
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        # resume epoch counting from a loaded checkpoint
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, iter, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
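
# Typical call site (a sketch -- the builder helpers named below are assumptions
# about the rest of this repo's training script, not definitions made here):
#
#   # train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
#   # model = build_model(cfg, num_classes)
#   # optimizer = make_optimizer(cfg, model)
#   # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA)
#   # loss_fn = make_loss(cfg, num_classes)
#   # do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn,
#   #          num_query, start_epoch=0)
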
def do_train_with_center(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    # save a checkpoint every `checkpoint_period` epochs, keeping the last 10
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict(),
                                                                     'optimizer_center': optimizer_center.state_dict()})

    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        # resume epoch counting from a loaded checkpoint
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, iter, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
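
# Call-site sketch for the center-loss variant (builder names are assumptions, as above):
#
#   # loss_fn, center_criterion = make_loss_with_center(cfg, num_classes)
#   # optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
#   # do_train_with_center(cfg, model, center_criterion, train_loader, val_loader,
#   #                      optimizer, optimizer_center, scheduler, loss_fn,
#   #                      num_query, start_epoch=0)
#
# Note that the checkpoint handler above stores the model and both optimizers but
# not the learnable centers themselves; adding `center_criterion.state_dict()` to
# the checkpoint dict would let the centers survive a restart.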