# Copyright (c) Alibaba, Inc. and its affiliates. import os from mmcv.runner import Hook from mmcv.runner.dist_utils import master_only from easycv.utils.config_tools import validate_export_config from .registry import HOOKS @HOOKS.register_module class ExportHook(Hook): """ export model when training on pai """ def __init__( self, cfg, ckpt_filename_tmpl='epoch_{}.pth', export_ckpt_filename_tmpl='epoch_{}_export.pt', export_after_each_ckpt=False, ): """ Args: cfg: config dict ckpt_filename_tmpl: checkpoint filename template """ self.cfg = validate_export_config(cfg) self.work_dir = cfg.work_dir self.ckpt_filename_tmpl = ckpt_filename_tmpl self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl self.export_after_each_ckpt = export_after_each_ckpt or cfg.get( 'export_after_each_ckpt', False) def export_model(self, runner, epoch): export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch) export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname) runner.logger.info(f'export model to {export_local_ckpt}') from easycv.apis.export import export if hasattr(runner.model, 'module'): model = runner.model.module else: model = runner.model export( self.cfg, ckpt_path='dummy', filename=export_local_ckpt, model=model) @master_only def after_train_iter(self, runner): pass @master_only def after_train_epoch(self, runner): # do export after every ckpy is right! should do so! if self.export_after_each_ckpt: self.export_model(runner, runner.epoch) pass @master_only def after_run(self, runner): self.export_model(runner, runner.epoch)