# PaddleOCR/benchmark/PaddleOCR_DBNet/trainer/trainer.py
# 231 lines, 10 KiB, Python

# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:58
# @Author : zhoujun
import time
import paddle
from tqdm import tqdm
from base import BaseTrainer
from utils import runningScore, cal_text_score, Polynomial, profiler
class Trainer(BaseTrainer):
    """Training / evaluation loop for the text-detection model.

    Most state used here (``self.model``, ``self.optimizer``, ``self.amp``,
    ``self.metrics``, ``self.writer``, ``self.global_step`` ...) is not created
    in this class; presumably it is initialised by ``BaseTrainer``.
    """

    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric_cls,
                 post_process=None,
                 profiler_options=None):
        """Forward everything except ``profiler_options`` to ``BaseTrainer``.

        Args:
            config: configuration dict; ``config['trainer']['enable_eval']``
                (default ``True``) toggles evaluation at the end of an epoch.
            model, criterion, train_loader, validate_loader, metric_cls,
            post_process: passed through unchanged to ``BaseTrainer``.
            profiler_options: forwarded to ``profiler.add_profiler_step`` on
                every training iteration.
        """
        super(Trainer, self).__init__(config, model, criterion, train_loader,
                                      validate_loader, metric_cls, post_process)
        self.profiler_options = profiler_options
        self.enable_eval = config['trainer'].get('enable_eval', True)

    def _amp_autocast(self):
        """Return a ``paddle.amp.auto_cast`` context configured from ``self.amp``.

        Autocast is only enabled when the current device string contains
        ``'gpu'``; the AMP level defaults to ``'O2'``.
        """
        return paddle.amp.auto_cast(
            enable='gpu' in paddle.device.get_device(),
            custom_white_list=self.amp.get('custom_white_list', []),
            custom_black_list=self.amp.get('custom_black_list', []),
            level=self.amp.get('level', 'O2'))

    def _train_epoch(self, epoch):
        """Run one training epoch (at most ``self.train_loader_len`` steps).

        Args:
            epoch: current epoch index, used for logging only.

        Returns:
            dict with ``train_loss`` (mean loss over the iterations actually
            run), ``lr`` (learning rate read at the last iteration), ``time``
            (epoch wall time in seconds) and ``epoch``.
        """
        self.model.train()
        # Counters for the current logging window; reset after each log line.
        window_steps = 0
        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        reader_start = time.time()
        epoch_start = time.time()
        train_loss = 0.
        num_steps = 0
        # Read once up front so ``lr`` is defined even if no batch is yielded.
        lr = self.optimizer.get_lr()
        running_metric_text = runningScore(2)
        for i, batch in enumerate(self.train_loader):
            profiler.add_profiler_step(self.profiler_options)
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            num_steps += 1
            window_steps += 1
            lr = self.optimizer.get_lr()
            cur_batch_size = batch['img'].shape[0]
            train_reader_cost += time.time() - reader_start
            if self.amp:
                with self._amp_autocast():
                    preds = self.model(batch['img'])
                # The loss is computed in float32 outside the autocast region.
                loss_dict = self.criterion(preds.astype(paddle.float32), batch)
                scaled_loss = self.amp['scaler'].scale(loss_dict['loss'])
                scaled_loss.backward()
                self.amp['scaler'].minimize(self.optimizer, scaled_loss)
            else:
                preds = self.model(batch['img'])
                loss_dict = self.criterion(preds, batch)
                # backward
                loss_dict['loss'].backward()
                self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.clear_grad()
            # Batch cost is measured from reader_start, so it includes reader time.
            train_batch_cost += time.time() - reader_start
            total_samples += cur_batch_size

            # acc / iou on channel 0 of the predictions vs. the shrink map
            score_shrink_map = cal_text_score(
                preds[:, 0, :, :],
                batch['shrink_map'],
                batch['shrink_mask'],
                running_metric_text,
                thred=self.config['post_processing']['args']['thresh'])

            # Record loss and acc to the log. The tensors in loss_dict are
            # replaced by Python floats (they are reused for visualdl below).
            loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == 'loss':
                    continue
                loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ', '

            train_loss += loss_dict['loss']
            acc = score_shrink_map['Mean Acc']
            iou_shrink_map = score_shrink_map['Mean IoU']

            if self.global_step % self.log_iter == 0:
                # Average over the iterations actually accumulated in this
                # window: the counters reset at every epoch start while
                # global_step does not, so a window can be shorter than log_iter.
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}'.
                    format(epoch, self.epochs, i + 1, self.train_loader_len,
                           self.global_step, total_samples / train_batch_cost,
                           train_reader_cost / window_steps,
                           train_batch_cost / window_steps,
                           total_samples / window_steps, acc, iou_shrink_map,
                           loss_str, lr, train_batch_cost))
                window_steps = 0
                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0

            if self.visualdl_enable and paddle.distributed.get_rank() == 0:
                # write tensorboard (visualdl) scalars
                for key, value in loss_dict.items():
                    self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value,
                                           self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map',
                                       iou_shrink_map, self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
            reader_start = time.time()
        return {
            # Divide by the iterations actually run (equal to train_loader_len
            # unless the loader is shorter); max() guards an empty loader.
            'train_loss': train_loss / max(num_steps, 1),
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval(self, epoch):
        """Evaluate the model on ``self.validate_loader``.

        Args:
            epoch: current epoch index (not used inside this method).

        Returns:
            ``(recall, precision, fmeasure)`` averages taken from
            ``self.metric_cls.gather_measure``.
        """
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for batch in tqdm(
                self.validate_loader,
                total=len(self.validate_loader),
                desc='test model'):
            with paddle.no_grad():
                start = time.time()
                if self.amp:
                    with self._amp_autocast():
                        preds = self.model(batch['img'])
                    preds = preds.astype(paddle.float32)
                else:
                    preds = self.model(batch['img'])
                boxes, scores = self.post_process(
                    batch,
                    preds,
                    is_output_polygon=self.metric_cls.is_output_polygon)
                total_frame += batch['img'].shape[0]
                # Timed span covers forward pass + post-processing only.
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch,
                                                              (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        # Guard against ZeroDivisionError when the validation loader is empty.
        if total_time > 0:
            self.logger_info('FPS:{}'.format(total_frame / total_time))
        return metrics['recall'].avg, metrics['precision'].avg, metrics[
            'fmeasure'].avg

    def _on_epoch_finish(self):
        """Log the epoch summary, save checkpoints and track the best model.

        Only rank 0 saves ``model_latest.pth``. When a validation loader and
        metric are available and ``enable_eval`` is set, the model is evaluated
        and the best checkpoint is chosen by hmean. Otherwise it is chosen by
        the lowest train loss. The best checkpoint is a copy of the latest one.
        """
        self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.
                         format(self.epoch_result['epoch'], self.epochs, self.
                                epoch_result['train_loss'], self.epoch_result[
                                    'time'], self.epoch_result['lr']))
        net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
        net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir)

        if paddle.distributed.get_rank() == 0:
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path)
            save_best = False
            # Use F1 (hmean) as the best-model criterion when evaluating.
            if self.validate_loader is not None and self.metric_cls is not None and self.enable_eval:
                recall, precision, hmean = self._eval(self.epoch_result[
                    'epoch'])

                if self.visualdl_enable:
                    self.writer.add_scalar('EVAL/recall', recall,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/precision', precision,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/hmean', hmean,
                                           self.global_step)
                self.logger_info(
                    'test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}'.
                    format(recall, precision, hmean))

                if hmean >= self.metrics['hmean']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result['train_loss']
                    self.metrics['hmean'] = hmean
                    self.metrics['precision'] = precision
                    self.metrics['recall'] = recall
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            else:
                # Without evaluation, fall back to the lowest training loss.
                if self.epoch_result['train_loss'] <= self.metrics[
                        'train_loss']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result['train_loss']
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            best_str = 'current best, '
            for k, v in self.metrics.items():
                best_str += '{}: {:.6f}, '.format(k, v)
            self.logger_info(best_str)
            if save_best:
                import shutil
                shutil.copy(net_save_path, net_save_path_best)
                self.logger_info("Saving current best: {}".format(
                    net_save_path_best))
            else:
                self.logger_info("Saving checkpoint: {}".format(net_save_path))

    def _on_train_finish(self):
        """Log the tracked best metrics (only when evaluation is enabled)."""
        if self.enable_eval:
            for k, v in self.metrics.items():
                self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')

    def _initialize_scheduler(self):
        """Build ``self.lr_scheduler`` from ``config['lr_scheduler']``.

        For ``Polynomial``, ``epochs`` and ``step_each_epoch`` are injected
        into the scheduler args before construction; any other type is
        resolved through ``self._initialize`` against ``paddle.optimizer.lr``.
        """
        if self.config['lr_scheduler']['type'] == 'Polynomial':
            self.config['lr_scheduler']['args']['epochs'] = self.config[
                'trainer']['epochs']
            # NOTE(review): _train_epoch stops after self.train_loader_len
            # steps; if that differs from len(self.train_loader) the schedule
            # length will not match the steps actually taken — confirm.
            self.config['lr_scheduler']['args']['step_each_epoch'] = len(
                self.train_loader)
            self.lr_scheduler = Polynomial(
                **self.config['lr_scheduler']['args'])()
        else:
            self.lr_scheduler = self._initialize('lr_scheduler',
                                                 paddle.optimizer.lr)