# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:50
# @Author : zhoujun
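"""Training scaffold shared by concrete trainers.

Builds the output/checkpoint directories, logging, the optimizer and lr
scheduler, and optional AMP and multi-GPU wrappers from a config dict, and
defines the checkpoint save/load plus the per-epoch hooks that subclasses
implement.
"""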
import os
import pathlib
import random
from pprint import pformat

import anyconfig
import numpy as np
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec

from utils import setup_logger


class BaseTrainer:
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric_cls,
                 post_process=None):
        config['trainer']['output_dir'] = os.path.join(
            str(pathlib.Path(os.path.abspath(__name__)).parent),
            config['trainer']['output_dir'])
        config['name'] = config['name'] + '_' + model.name
        self.save_dir = config['trainer']['output_dir']
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.criterion = criterion
        # logger and visualdl
        self.visualdl_enable = self.config['trainer'].get('visual_dl', False)
        self.epochs = self.config['trainer']['epochs']
        self.log_iter = self.config['trainer']['log_iter']
        if paddle.distributed.get_rank() == 0:
            anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
        self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
        self.logger_info(pformat(self.config))

        self.model = self.apply_to_static(model)

        # device and reproducibility: seed the python, numpy and paddle RNGs
        # unconditionally so CPU runs are repeatable as well as GPU runs
        seed = self.config['trainer']['seed']
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)
        self.with_cuda = (paddle.device.cuda.device_count() > 0 and
                          paddle.device.is_compiled_with_cuda())
        self.logger_info('train with {} and paddle {}'.format(
            'gpu' if self.with_cuda else 'cpu', paddle.__version__))
        # best metrics achieved so far; train_loss starts at +inf so the
        # first measured loss always improves on it
        self.metrics = {
            'recall': 0,
            'precision': 0,
            'hmean': 0,
            'train_loss': float('inf'),
            'best_model_epoch': 0
        }

        self.train_loader = train_loader
        # always set validate_loader (possibly None): the check below and
        # subclasses rely on the attribute existing
        self.validate_loader = validate_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
            self.post_process = post_process
            self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)

        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples, {} in dataloader, validate dataset has {} samples, {} in dataloader'.
                format(
                    len(self.train_loader.dataset), self.train_loader_len,
                    len(self.validate_loader.dataset),
                    len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples, {} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

        self._initialize_scheduler()
        self._initialize_optimizer()

        # resume or finetune
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['resume_checkpoint'], resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['finetune_checkpoint'], resume=False)

        if self.visualdl_enable and paddle.distributed.get_rank() == 0:
            from visualdl import LogWriter
            self.writer = LogWriter(self.save_dir)

        # mixed-precision (AMP) training
        self.amp = self.config.get('amp', None)
        if self.amp == 'None':  # a YAML value of None may arrive as the string 'None'
            self.amp = None
        if self.amp:
            self.amp['scaler'] = paddle.amp.GradScaler(
                init_loss_scaling=self.amp.get("scale_loss", 1024),
                use_dynamic_loss_scaling=self.amp.get(
                    'use_dynamic_loss_scaling', True))
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=self.amp.get('amp_level', 'O2'))

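        # A sketch of the optional `amp` config section consumed above; the
        # key names mirror the .get() calls, and the values shown are the
        # defaults used when a key is absent:
        #
        #   amp:
        #     scale_loss: 1024
        #     use_dynamic_loss_scaling: true
        #     amp_level: O2
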
        # distributed training
        if paddle.device.cuda.device_count() > 1:
            self.model = paddle.DataParallel(self.model)
        # record the mean/std of the train-time Normalize transform so that
        # inverse_normalize() can restore images for visualization
        self.UN_Normalize = False
        for t in self.config['dataset']['train']['dataset']['args'][
                'transforms']:
            if t['type'] == 'Normalize':
                self.normalize_mean = t['args']['mean']
                self.normalize_std = t['args']['std']
                self.UN_Normalize = True

    def apply_to_static(self, model):
        support_to_static = self.config['trainer'].get('to_static', False)
        if support_to_static:
            # convert the dygraph model to a static graph; the input is NCHW
            # with 3 channels and variable batch and spatial sizes
            specs = [InputSpec([None, 3, -1, -1])]
            model = to_static(model, input_spec=specs)
            self.logger_info(
                "Successfully applied @to_static with specs: {}".format(specs))
        return model

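    # Note: `to_static` defaults to False, so the dygraph-to-static conversion
    # above is opt-in via the trainer config; the single InputSpec assumes the
    # model takes one [N, 3, H, W] tensor.
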
    def train(self):
        """
        Full training logic
        """
        for epoch in range(self.start_epoch + 1, self.epochs + 1):
            self.epoch_result = self._train_epoch(epoch)
            self._on_epoch_finish()
        if paddle.distributed.get_rank() == 0 and self.visualdl_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _eval(self, epoch):
        """
        Eval logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _on_epoch_finish(self):
        raise NotImplementedError

    def _on_train_finish(self):
        raise NotImplementedError

    def _save_checkpoint(self, epoch, file_name):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param file_name: checkpoint file name, saved under self.checkpoint_dir
        """
        state_dict = self.model.state_dict()
        state = {
            'epoch': epoch,
            'global_step': self.global_step,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
            'metrics': self.metrics
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        paddle.save(state, filename)

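    # Hypothetical call site (subclasses decide when to checkpoint):
    #   self._save_checkpoint(epoch, 'model_latest.pth')
    # paddle.save serializes the whole state dict above to a single file that
    # _load_checkpoint below reads back with paddle.load.
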
    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Resume from saved checkpoints

        :param checkpoint_path: Checkpoint path to be resumed
        :param resume: if True, also restore epoch, global step, optimizer
            state and metrics; if False, only load the model weights (finetune)
        """
        self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = paddle.load(checkpoint_path)
        self.model.set_state_dict(checkpoint['state_dict'])
        if resume:
            self.global_step = checkpoint['global_step']
            self.start_epoch = checkpoint['epoch']
            self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
            # self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.set_state_dict(checkpoint['optimizer'])
            if 'metrics' in checkpoint:
                self.metrics = checkpoint['metrics']
            self.logger_info("resume from checkpoint {} (epoch {})".format(
                checkpoint_path, self.start_epoch))
        else:
            self.logger_info("finetune from checkpoint {}".format(
                checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
        # instantiate the class named by config[name]['type'] from `module`,
        # passing config[name]['args'] merged with any explicit kwargs
        module_name = self.config[name]['type']
        module_args = self.config[name].get('args', {})
        assert all(k not in module_args for k in kwargs), \
            'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

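    # For example, with an illustrative config fragment (not one shipped with
    # this file):
    #
    #   optimizer:
    #     type: Adam
    #     args:
    #       weight_decay: 0.0001
    #
    # _initialize('optimizer', paddle.optimizer, ...) resolves to
    # paddle.optimizer.Adam(weight_decay=0.0001,
    #                       parameters=self.model.parameters(),
    #                       learning_rate=self.lr_scheduler).
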
    def _initialize_scheduler(self):
        self.lr_scheduler = self._initialize('lr_scheduler',
                                             paddle.optimizer.lr)

    def _initialize_optimizer(self):
        self.optimizer = self._initialize(
            'optimizer',
            paddle.optimizer,
            parameters=self.model.parameters(),
            learning_rate=self.lr_scheduler)

    def inverse_normalize(self, batch_img):
        # undo the channel-wise (img - mean) / std applied by the Normalize
        # transform, in place, so batches can be visualized
        if self.UN_Normalize:
            batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[0] + self.normalize_mean[0]
            batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[1] + self.normalize_mean[1]
            batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[2] + self.normalize_mean[2]

    def logger_info(self, s):
        # only rank 0 writes to the log in distributed training
        if paddle.distributed.get_rank() == 0:
            self.logger.info(s)
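

# A minimal sketch of a concrete trainer (illustrative only; `MyTrainer` and
# its bodies are assumptions, not part of this repo):
#
#   class MyTrainer(BaseTrainer):
#       def _train_epoch(self, epoch):
#           for batch in self.train_loader:
#               ...  # forward, self.criterion, backward, optimizer step
#           return {'epoch': epoch, 'train_loss': ...}
#
#       def _eval(self, epoch):
#           ...  # run validate_loader through post_process and metric_cls
#
#       def _on_epoch_finish(self):
#           ...  # log self.epoch_result, call _eval / _save_checkpoint
#
#       def _on_train_finish(self):
#           ...  # final logging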