# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only

from easycv.utils.config_tools import validate_export_config
from .registry import HOOKS


@HOOKS.register_module
class ExportHook(Hook):
    """Hook that exports saved checkpoints (export model when training on pai).

    A checkpoint is located in the work dir by formatting
    ``ckpt_filename_tmpl`` with an epoch number, and is handed to
    ``easycv.apis.export.export`` together with the target path built from
    ``export_ckpt_filename_tmpl``. An export always happens in ``after_run``;
    it additionally happens in ``after_train_epoch`` when
    ``export_after_each_ckpt`` is enabled.
    """

    def __init__(
        self,
        cfg,
        ckpt_filename_tmpl='epoch_{}.pth',
        export_ckpt_filename_tmpl='epoch_{}_export.pt',
        export_after_each_ckpt=False,
    ):
        """
        Args:
            cfg: config dict; it is passed through ``validate_export_config``
                and must provide ``work_dir``.
            ckpt_filename_tmpl: checkpoint filename template, formatted with
                the epoch number.
            export_ckpt_filename_tmpl: filename template of the exported
                model, formatted with the epoch number.
            export_after_each_ckpt: whether to also export at the end of
                every training epoch. When falsy, the value of
                ``cfg['export_after_each_ckpt']`` (default ``False``) is
                used instead.
        """
        self.cfg = validate_export_config(cfg)
        self.work_dir = cfg.work_dir
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl

        # The explicit argument wins when truthy; otherwise fall back to cfg.
        if export_after_each_ckpt:
            self.export_after_each_ckpt = export_after_each_ckpt
        else:
            self.export_after_each_ckpt = cfg.get('export_after_each_ckpt',
                                                  False)

    def export_model(self, runner, epoch):
        """Export the checkpoint of ``epoch`` if it exists in the work dir.

        Args:
            runner: the runner; only its ``logger`` is used here.
            epoch: number substituted into both filename templates.
        """
        local_ckpt = os.path.join(self.work_dir,
                                  self.ckpt_filename_tmpl.format(epoch))
        export_local_ckpt = os.path.join(
            self.work_dir, self.export_ckpt_filename_tmpl.format(epoch))

        # Nothing to export when the source checkpoint is missing.
        if not os.path.exists(local_ckpt):
            runner.logger.warning(f'{local_ckpt} does not exists, skip export')
            return

        runner.logger.info(f'export {local_ckpt} to {export_local_ckpt}')
        # Imported lazily, only when an export is actually performed.
        from easycv.apis.export import export
        export(self.cfg, local_ckpt, export_local_ckpt)

    @master_only
    def after_train_iter(self, runner):
        """No per-iteration work is done by this hook."""
        pass

    @master_only
    def after_train_epoch(self, runner):
        """Export after the epoch when ``export_after_each_ckpt`` is enabled."""
        # NOTE(review): the checkpoint looked up is the one named with
        # ``runner.epoch`` as seen at this point - confirm this matches the
        # index used by the checkpoint saver for the file it just wrote.
        if self.export_after_each_ckpt:
            self.export_model(runner, runner.epoch)

    @master_only
    def after_run(self, runner):
        """Export the checkpoint named with ``runner.epoch`` once the run ends."""
        self.export_model(runner, runner.epoch)