mirror of https://github.com/alibaba/EasyCV.git
124 lines
4.7 KiB
Python
124 lines
4.7 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import glob
|
|
import os
|
|
|
|
from mmcv.runner import Hook
|
|
from mmcv.runner.dist_utils import master_only
|
|
|
|
from easycv.file import io
|
|
from .registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module
class OSSSyncHook(Hook):
    """Upload log files and checkpoints to OSS when training on PAI.

    Files are copied from the local ``work_dir`` to ``oss_work_dir`` with
    ``io.safe_copy``, keeping their path relative to ``work_dir``.
    """

    def __init__(self,
                 work_dir,
                 oss_work_dir,
                 interval=1,
                 ckpt_filename_tmpl='epoch_{}.pth',
                 export_ckpt_filename_tmpl='epoch_{}_export.pt',
                 other_file_list=None,
                 iter_interval=None):
        """
        Args:
            work_dir: work_dir in cfg
            oss_work_dir: oss directory where to upload local files in work_dir
            interval: upload frequency, in epochs
            ckpt_filename_tmpl: checkpoint filename template, formatted with
                the epoch number
            export_ckpt_filename_tmpl: exported model filename template,
                formatted with the epoch number; used in ``after_run``
            other_file_list: other files (glob patterns relative to
                work_dir) that need to be uploaded to oss; ``None`` means
                no extra files
            iter_interval: upload frequency in iterations; ``None`` (default)
                disables the per-iteration upload
        """
        self.work_dir = work_dir
        self.oss_work_dir = oss_work_dir
        self.interval = interval
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        # Build a fresh list per instance: a ``[]`` default in the signature
        # would be shared by every hook created without this argument.
        self.other_file_list = (
            [] if other_file_list is None else list(other_file_list))
        self.iter_interval = iter_interval

    def upload_file(self, runner):
        """Upload the current epoch checkpoint and the extra files.

        The extra files are ``self.other_file_list`` plus
        ``runner.file_upload_perepoch`` when the runner defines it. Each
        entry is treated as a glob pattern relative to ``work_dir``.

        Args:
            runner: the training runner; ``epoch`` and ``logger`` are read.
        """
        if hasattr(runner, 'file_upload_perepoch'):
            patterns = runner.file_upload_perepoch + self.other_file_list
        else:
            patterns = self.other_file_list

        # Drop duplicates while keeping the configured order, so the upload
        # order is the same on every run.
        patterns = list(dict.fromkeys(patterns))

        # The checkpoint file is looked up under ``runner.epoch + 1``
        # (presumably because the runner bumps ``epoch`` only after the
        # epoch hooks have run -- confirm against the runner in use).
        epoch = runner.epoch + 1

        # try to upload ckpt model
        ckpt_fname = self.ckpt_filename_tmpl.format(epoch)
        local_ckpt = os.path.join(self.work_dir, ckpt_fname)
        oss_ckpt = os.path.join(self.oss_work_dir, ckpt_fname)
        if os.path.exists(local_ckpt):
            runner.logger.info(f'upload {local_ckpt} to {oss_ckpt}')
            io.safe_copy(local_ckpt, oss_ckpt)
        else:
            runner.logger.warning(f'{local_ckpt} does not exists, skip upload')

        for pattern in patterns:
            local_files = glob.glob(
                os.path.join(self.work_dir, pattern), recursive=True)
            for local_file in local_files:
                rel_path = os.path.relpath(local_file, self.work_dir)
                oss_file = os.path.join(self.oss_work_dir, rel_path)
                # Log the file actually being copied, not the whole
                # pattern list.
                runner.logger.info(f'upload {local_file} to {oss_file}')
                io.safe_copy(local_file, oss_file)

        # local_tf_logs = os.path.join(self.work_dir, 'tf_logs')
        # oss_tf_logs = os.path.join(self.oss_work_dir, 'tf_logs')
        # runner.logger.info(f'upload directory {local_tf_logs} to {oss_tf_logs}')
        # io.copytree(local_tf_logs, oss_tf_logs)

    # Iteration-based upload of checkpoints and logs; only active when
    # ``iter_interval`` is configured.
    @master_only
    def after_train_iter(self, runner):
        """Upload every ``iter_interval`` inner iterations, if configured."""
        if self.iter_interval is None:
            return
        if self.every_n_inner_iters(runner, self.iter_interval):
            self.upload_file(runner)

    @master_only
    def after_train_epoch(self, runner):
        """Upload checkpoint and extra files every ``interval`` epochs."""
        if not self.every_n_epochs(runner, self.interval):
            return
        self.upload_file(runner)

    @master_only
    def after_run(self, runner):
        """Upload the final log files and any exported model files.

        Args:
            runner: the training runner; ``timestamp``, ``epoch`` and
                ``logger`` are read.
        """
        # upload final log files
        timestamp = runner.timestamp
        upload_files = [
            '{}.log'.format(timestamp),
            '{}.log.json'.format(timestamp),
        ]
        for log_file in upload_files:
            local_log = os.path.join(self.work_dir, log_file)
            if not os.path.exists(local_log):
                runner.logger.warning(
                    f'{local_log} does not exists, skip upload')
                continue
            oss_log = os.path.join(self.oss_work_dir, log_file)
            runner.logger.info(f'upload {local_log} to {oss_log}')
            io.safe_copy(local_log, oss_log)

        # try to upload exported model; unlike ``upload_file`` this uses
        # ``runner.epoch`` as-is for the file name.
        epoch = runner.epoch
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)

        # upload all export files whose name contains the export file name
        export_files = glob.glob(
            os.path.join(self.work_dir, '*{}*'.format(export_ckpt_fname)),
            recursive=True)
        for export_file in export_files:
            rel_path = os.path.relpath(export_file, self.work_dir)
            target_oss_path = os.path.join(self.oss_work_dir, rel_path)
            runner.logger.info(f'upload {export_file} to {target_oss_path}')
            io.safe_copy(export_file, target_oss_path)
|