EasyCV/easycv/hooks/export_hook.py

66 lines
1.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only
from easycv.utils.config_tools import validate_export_config
from .registry import HOOKS
@HOOKS.register_module
class ExportHook(Hook):
    """Export the trained model when training on PAI.

    The model is always exported once when the run finishes (``after_run``).
    If ``export_after_each_ckpt`` is enabled, via either the constructor
    argument or the ``export_after_each_ckpt`` key of ``cfg``, the model is
    also exported after every training epoch.

    The hook callbacks are decorated with ``master_only``.
    """

    def __init__(
        self,
        cfg,
        ckpt_filename_tmpl='epoch_{}.pth',
        export_ckpt_filename_tmpl='epoch_{}_export.pt',
        export_after_each_ckpt=False,
    ):
        """
        Args:
            cfg: config dict. It is passed through ``validate_export_config``
                and must provide ``work_dir``, the directory that exported
                files are written to. ``cfg.get('export_after_each_ckpt')``
                is also consulted.
            ckpt_filename_tmpl: checkpoint filename template. It is stored on
                the hook but is not used by this class.
            export_ckpt_filename_tmpl: filename template for exported models.
                ``str.format`` is called on it with the epoch number.
            export_after_each_ckpt: if True, also export after every training
                epoch. Per-epoch export is enabled when either this argument
                or ``cfg['export_after_each_ckpt']`` is truthy.
        """
        self.cfg = validate_export_config(cfg)
        # NOTE(review): work_dir is read from the raw ``cfg`` rather than the
        # validated ``self.cfg``. This is kept as-is to preserve behavior;
        # confirm the two always agree.
        self.work_dir = cfg.work_dir
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        self.export_after_each_ckpt = export_after_each_ckpt or cfg.get(
            'export_after_each_ckpt', False)

    def export_model(self, runner, epoch):
        """Export ``runner.model`` to a file inside ``self.work_dir``.

        Args:
            runner: the training runner. ``runner.model`` is exported, and
                ``runner.logger`` reports the destination path.
            epoch: value substituted into ``export_ckpt_filename_tmpl``.
        """
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
        export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)

        runner.logger.info(f'export model to {export_local_ckpt}')
        # Imported lazily, presumably to avoid an import cycle between the
        # hooks and apis packages. TODO confirm.
        from easycv.apis.export import export

        # Unwrap a model wrapper that exposes ``.module`` (presumably
        # (Distributed)DataParallel) so the bare model is exported.
        model = getattr(runner.model, 'module', runner.model)
        export(
            self.cfg,
            # Placeholder path: the in-memory model is handed over via
            # ``model=`` below.
            ckpt_path='dummy',
            filename=export_local_ckpt,
            model=model)

    @master_only
    def after_train_iter(self, runner):
        """Intentionally a no-op; nothing is exported per iteration."""

    @master_only
    def after_train_epoch(self, runner):
        """Export after each training epoch when per-epoch export is enabled."""
        if self.export_after_each_ckpt:
            # NOTE(review): this uses the same ``runner.epoch`` value as
            # ``after_run``. Depending on when the runner advances its epoch
            # counter, the final model may be exported under two different
            # names, and the names may be offset from the checkpoint names.
            # Verify against the runner.
            self.export_model(runner, runner.epoch)

    @master_only
    def after_run(self, runner):
        """Export the final model once the run has finished."""
        self.export_model(runner, runner.epoch)