mirror of https://github.com/JDAI-CV/fast-reid.git
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
# encoding: utf-8
|
|
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import time
|
|
|
|
from utils.meters import AverageMeter
|
|
|
|
|
|
class BaseTrainer(object):
    """Skeleton one-epoch training loop shared by concrete trainers.

    Subclasses must implement ``_parse_data``, ``_forward`` and
    ``_backward``. For every batch, ``train`` does the following in order:

    1. Calls ``_parse_data(inputs)``.
    2. Calls ``_forward()``.
    3. Zeroes the optimizer gradients.
    4. Calls ``_backward()``.
    5. Steps the optimizer.
    6. Reads ``self.loss``.

    One of the hooks must therefore leave the scalar loss in ``self.loss``
    (something supporting ``.item()``).
    """

    def __init__(self, opt, model, optimzier, criterion, summary_writer):
        """Store the collaborators used by the training loop.

        Args:
            opt: options object; ``train`` reads ``opt.print_freq``.
            model: network; ``train`` calls ``model.train()`` on it.
            optimzier: optimizer, stored as ``self.optimizer``. The
                misspelled parameter name is kept so existing keyword
                callers keep working. ``train`` uses ``zero_grad()``,
                ``step()`` and ``param_groups[0]['lr']``.
            criterion: stored as ``self.criterion``; not used by this base
                class directly (presumably the loss function used by
                subclasses in ``_forward``).
            summary_writer: writer exposing ``add_scalar(tag, value, step)``
                (used for TensorBoard logging).
        """
        self.opt = opt
        self.model = model
        self.optimizer = optimzier
        self.criterion = criterion
        self.summary_writer = summary_writer

    def train(self, epoch, data_loader):
        """Run one training epoch over ``data_loader``.

        Logging happens at three points:

        * Every step, logs ``'loss'`` and ``'lr'`` (first param group) to
          ``summary_writer`` at global step ``epoch * len(data_loader) + i``.
        * Every ``opt.print_freq`` batches, prints timing and loss averages.
        * At the end of the epoch, prints a one-line summary.

        Args:
            epoch: epoch index, used for the global step and in log lines.
            data_loader: sized iterable of batches; each batch is passed
                unchanged to ``_parse_data``.
        """
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        # The loader length does not change during the epoch; compute it once.
        num_batches = len(data_loader)

        start = time.time()
        for i, inputs in enumerate(data_loader):
            # Time spent waiting for this batch since the previous step ended.
            data_time.update(time.time() - start)

            # model optimizer
            self._parse_data(inputs)
            self._forward()
            self.optimizer.zero_grad()
            self._backward()
            self.optimizer.step()

            # Whole step time: data wait + forward/backward/update.
            batch_time.update(time.time() - start)

            # Read the loss once. ``.item()`` copies the value to the host,
            # which is a device sync if ``self.loss`` lives on a GPU
            # (presumably). Calling it twice per step was redundant.
            loss_value = self.loss.item()
            losses.update(loss_value)

            # tensorboard
            global_step = epoch * num_batches + i
            self.summary_writer.add_scalar('loss', loss_value, global_step)
            self.summary_writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], global_step)

            start = time.time()

            if (i + 1) % self.opt.print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Batch Time {:.3f} ({:.3f})\t'
                      'Data Time {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      .format(epoch, i + 1, num_batches,
                              batch_time.val, batch_time.mean,
                              data_time.val, data_time.mean,
                              losses.val, losses.mean))

        # Epoch summary: "Epoch Time" is the sum of the per-step times above.
        param_group = self.optimizer.param_groups
        print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t'
              'Lr {:.2e}'
              .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr']))
        print()

    def _parse_data(self, inputs):
        """Consume one batch yielded by the data loader.

        Must be overridden. ``train`` ignores the return value and then
        calls ``_forward()`` with no arguments. Implementations presumably
        store the parsed batch on ``self``.
        """
        raise NotImplementedError

    def _forward(self):
        """Run the forward pass for the most recently parsed batch.

        Must be overridden. ``train`` reads ``self.loss`` after
        ``_backward()``, so this hook (or ``_backward``) must set it.
        """
        raise NotImplementedError

    def _backward(self):
        """Back-propagate the loss computed by ``_forward``.

        Must be overridden. ``train`` calls it after ``optimizer.zero_grad()``
        and before ``optimizer.step()``.
        """
        raise NotImplementedError
|