EasyCV/easycv/hooks/export_hook.py

64 lines
2.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only
from easycv.utils.config_tools import validate_export_config
from .registry import HOOKS
@HOOKS.register_module
class ExportHook(Hook):
    """Hook that exports saved checkpoints (e.g. when training on PAI).

    For a given epoch number the hook looks for the checkpoint file built
    from ``ckpt_filename_tmpl`` inside the work directory and, when that
    file exists, passes it to ``easycv.apis.export.export`` together with
    the destination path built from ``export_ckpt_filename_tmpl``.

    An export is always attempted in ``after_run``; it is additionally
    attempted at the end of every training epoch when
    ``export_after_each_ckpt`` is enabled.
    """

    def __init__(
        self,
        cfg,
        ckpt_filename_tmpl='epoch_{}.pth',
        export_ckpt_filename_tmpl='epoch_{}_export.pt',
        export_after_each_ckpt=False,
    ):
        """Store the export configuration and filename templates.

        Args:
            cfg: config object; it must expose ``work_dir`` and ``get``.
                It is passed through ``validate_export_config`` and the
                result is kept as ``self.cfg``.
            ckpt_filename_tmpl: ``str.format`` template for the name of
                the checkpoint to read; formatted with the epoch number.
            export_ckpt_filename_tmpl: ``str.format`` template for the
                name of the exported file; formatted with the epoch
                number.
            export_after_each_ckpt: when truthy, also try to export at
                the end of every training epoch. When falsy, the value
                of ``cfg.get('export_after_each_ckpt', False)`` is used
                instead.
        """
        self.cfg = validate_export_config(cfg)
        self.work_dir = cfg.work_dir
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl

        # The explicit argument wins when truthy; otherwise fall back to
        # the flag stored in the config (default False).
        per_epoch_export = (
            export_after_each_ckpt
            or cfg.get('export_after_each_ckpt', False))
        self.export_after_each_ckpt = per_epoch_export

    def export_model(self, runner, epoch):
        """Export the checkpoint saved for ``epoch``, if it exists.

        Args:
            runner: training runner; only ``runner.logger`` is used here.
            epoch: number substituted into both filename templates.

        A missing checkpoint file is not an error: a warning is logged
        and the export is skipped.
        """
        templates = (self.ckpt_filename_tmpl, self.export_ckpt_filename_tmpl)
        file_names = [tmpl.format(epoch) for tmpl in templates]
        src_path, dst_path = [
            os.path.join(self.work_dir, name) for name in file_names
        ]

        if not os.path.exists(src_path):
            runner.logger.warning(f'{src_path} does not exists, skip export')
            return

        runner.logger.info(f'export {src_path} to {dst_path}')

        # Imported at call time rather than at module level; presumably
        # to avoid an import cycle or a heavy import -- TODO confirm.
        from easycv.apis.export import export

        export(self.cfg, src_path, dst_path)

    @master_only
    def after_train_iter(self, runner):
        """Do nothing after a training iteration."""

    @master_only
    def after_train_epoch(self, runner):
        """Optionally export the checkpoint at the end of a training epoch."""
        # Exporting right after each checkpoint is the intended behaviour
        # when the flag is enabled.
        # NOTE(review): confirm that ``runner.epoch`` at this point equals
        # the number used in the name of the checkpoint that was just
        # written (possible off-by-one) -- verify against the runner and
        # checkpoint hook in use.
        if self.export_after_each_ckpt:
            self.export_model(runner, runner.epoch)

    @master_only
    def after_run(self, runner):
        """Export the checkpoint for ``runner.epoch`` once the run ends."""
        self.export_model(runner, runner.epoch)