fast-reid/bases/base_trainer.py

76 lines
2.3 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import time
from utils.meters import AverageMeter
class BaseTrainer(object):
    """Generic one-epoch training loop; concrete trainers supply the hooks.

    Subclasses implement:

    * ``_parse_data(inputs)`` -- receives each batch yielded by the data loader;
    * ``_forward()`` / ``_backward()`` -- the forward and backward passes.

    One of the hooks must set ``self.loss`` to an object with an ``item()``
    method (presumably a scalar loss tensor). ``train`` reads it after the
    optimizer step.

    ``train`` itself calls ``optimizer.zero_grad()`` before ``_backward`` and
    ``optimizer.step()`` after it. Hooks presumably should not repeat those
    calls.
    """

    def __init__(self, opt, model, optimzier, criterion, summary_writer):
        """Store the training collaborators.

        Args:
            opt: options object; ``train`` reads ``opt.print_freq``.
            model: model switched to training mode via ``model.train()``.
            optimzier: optimizer exposing ``zero_grad()``, ``step()`` and
                ``param_groups``. The misspelled parameter name is kept so
                existing keyword-argument callers keep working.
            criterion: loss function. It is not used by this base class;
                presumably it is consumed by subclass hooks.
            summary_writer: object with ``add_scalar(tag, value, step)``
                (the loop logs to it as its TensorBoard sink).
        """
        self.opt = opt
        self.model = model
        self.optimizer = optimzier
        self.criterion = criterion
        self.summary_writer = summary_writer

    def train(self, epoch, data_loader):
        """Run one training epoch over ``data_loader``.

        Per-iteration loss and learning rate (of the first param group) are
        logged to ``summary_writer`` at
        ``global_step = epoch * len(data_loader) + i``. A progress line is
        printed every ``opt.print_freq`` iterations, and an epoch summary is
        printed at the end.

        Args:
            epoch: index of the current epoch, used for the global step and
                in the printed messages.
            data_loader: sized iterable of batches; each batch is handed to
                ``_parse_data``.
        """
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        num_batches = len(data_loader)
        start = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - start)

            # model optimizer
            self._parse_data(inputs)
            self._forward()
            self.optimizer.zero_grad()
            self._backward()
            self.optimizer.step()

            # Convert the loss exactly once, and do it BEFORE reading the
            # batch timer. If the loss lives on a GPU, .item() waits for the
            # queued forward/backward/step work to finish. Timing after it
            # means that compute is counted in batch_time instead of being
            # lost between the two timers.
            loss_value = self.loss.item()
            batch_time.update(time.time() - start)
            losses.update(loss_value)

            # tensorboard
            global_step = epoch * num_batches + i
            self.summary_writer.add_scalar('loss', loss_value, global_step)
            self.summary_writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], global_step)

            start = time.time()
            if (i + 1) % self.opt.print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Batch Time {:.3f} ({:.3f})\t'
                      'Data Time {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      .format(epoch, i + 1, num_batches,
                              batch_time.val, batch_time.mean,
                              data_time.val, data_time.mean,
                              losses.val, losses.mean))

        param_group = self.optimizer.param_groups
        print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t'
              'Lr {:.2e}'
              .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr']))
        print()

    def _parse_data(self, inputs):
        """Hook: unpack one batch from the data loader (subclass-defined)."""
        raise NotImplementedError

    def _forward(self):
        """Hook: forward pass (subclass-defined)."""
        raise NotImplementedError

    def _backward(self):
        """Hook: backward pass (subclass-defined)."""
        raise NotImplementedError