mirror of https://github.com/alibaba/EasyCV.git
66 lines
1.9 KiB
Python
66 lines
1.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
|
|
from mmcv.runner import Hook
|
|
from mmcv.runner.dist_utils import master_only
|
|
|
|
from easycv.utils.config_tools import validate_export_config
|
|
from .registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module
class ExportHook(Hook):
    """Export the trained model from within the training loop (used on PAI).

    The model is always exported once when the run finishes (``after_run``).
    If ``export_after_each_ckpt`` is enabled — through the constructor
    argument or the ``export_after_each_ckpt`` key of ``cfg`` — it is also
    exported at the end of every training epoch.

    All hook methods are wrapped with mmcv's ``master_only`` guard, so the
    export work is done only on the master (rank-0) process.
    """

    def __init__(
        self,
        cfg,
        ckpt_filename_tmpl='epoch_{}.pth',
        export_ckpt_filename_tmpl='epoch_{}_export.pt',
        export_after_each_ckpt=False,
    ):
        """
        Args:
            cfg: config dict. It is passed through ``validate_export_config``
                and must provide ``work_dir``, the directory exported files
                are written to.
            ckpt_filename_tmpl: checkpoint filename template. It is stored on
                the hook but not used by this class.
            export_ckpt_filename_tmpl: ``str.format`` template for the
                exported file name. It is formatted with the epoch number.
            export_after_each_ckpt: if True, also export after every training
                epoch. Export is enabled as well when ``cfg`` sets
                ``export_after_each_ckpt`` to a truthy value.
        """
        self.cfg = validate_export_config(cfg)
        self.work_dir = cfg.work_dir
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        # Either source can switch per-epoch export on; neither can force it off.
        self.export_after_each_ckpt = export_after_each_ckpt or cfg.get(
            'export_after_each_ckpt', False)

    def export_model(self, runner, epoch):
        """Export the runner's current model into ``work_dir``.

        Args:
            runner: the mmcv runner driving training. Its ``model`` and
                ``logger`` are used.
            epoch: value substituted into ``export_ckpt_filename_tmpl`` to
                build the output file name.
        """
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
        export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)

        runner.logger.info(f'export model to {export_local_ckpt}')
        # Imported at call time rather than at module level. This is
        # presumably done to avoid an import cycle between hooks and apis;
        # verify before hoisting it.
        from easycv.apis.export import export

        # Unwrap a wrapper that exposes the real network as ``.module``
        # (e.g. (Distributed)DataParallel); otherwise use the model as-is.
        model = getattr(runner.model, 'module', runner.model)
        # The in-memory model is handed over directly. ``ckpt_path`` gets a
        # placeholder, presumably ignored when ``model`` is given; confirm
        # this in easycv.apis.export.
        export(
            self.cfg,
            ckpt_path='dummy',
            filename=export_local_ckpt,
            model=model)

    @master_only
    def after_train_iter(self, runner):
        # Intentionally a no-op: nothing is exported per iteration.
        pass

    @master_only
    def after_train_epoch(self, runner):
        """Export at the end of each epoch when ``export_after_each_ckpt`` is set."""
        # NOTE(review): the file name uses ``runner.epoch`` unchanged. If the
        # runner increments its epoch counter only after this hook (as mmcv's
        # EpochBasedRunner appears to), these names may be one lower than the
        # matching ``epoch_{}.pth`` checkpoint and than the name used by
        # ``after_run``. Confirm before relying on the numbering.
        if self.export_after_each_ckpt:
            self.export_model(runner, runner.epoch)

    @master_only
    def after_run(self, runner):
        """Export the final model once the run has finished."""
        self.export_model(runner, runner.epoch)
|