EasyCV/easycv/hooks/oss_sync_hook.py

124 lines
4.7 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only
from easycv.file import io
from .registry import HOOKS
@HOOKS.register_module
class OSSSyncHook(Hook):
    """Upload log files and checkpoints to OSS when training on PAI.

    Files are copied from the local ``work_dir`` to ``oss_work_dir`` with
    ``io.safe_copy``, keeping their path relative to ``work_dir``.
    """

    def __init__(self,
                 work_dir,
                 oss_work_dir,
                 interval=1,
                 ckpt_filename_tmpl='epoch_{}.pth',
                 export_ckpt_filename_tmpl='epoch_{}_export.pt',
                 other_file_list=None,
                 iter_interval=None):
        """
        Args:
            work_dir: work_dir in cfg (local directory that is scanned).
            oss_work_dir: oss directory where local files in work_dir are
                uploaded to.
            interval: upload frequency, in epochs.
            ckpt_filename_tmpl: checkpoint filename template, formatted
                with the epoch number.
            export_ckpt_filename_tmpl: exported model filename template,
                formatted with the epoch number; matching files are
                uploaded in ``after_run``.
            other_file_list: extra glob patterns (relative to work_dir) of
                files to upload together with the checkpoint. Defaults to
                no extra files.
            iter_interval: if not None, additionally upload every
                ``iter_interval`` inner iterations; None disables the
                iteration-based upload.
        """
        self.work_dir = work_dir
        self.oss_work_dir = oss_work_dir
        self.interval = interval
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        # Copy into a fresh list so instances never share (or mutate) a
        # default list or the caller's list.
        self.other_file_list = (
            [] if other_file_list is None else list(other_file_list))
        self.iter_interval = iter_interval

    def _upload_glob(self, runner, pattern):
        """Upload every file under work_dir that matches ``pattern``.

        ``pattern`` is a glob pattern relative to ``self.work_dir``; each
        match is copied to the same relative path under
        ``self.oss_work_dir``.
        """
        matches = glob.glob(
            os.path.join(self.work_dir, pattern), recursive=True)
        for local_file in matches:
            rel_path = os.path.relpath(local_file, self.work_dir)
            oss_file = os.path.join(self.oss_work_dir, rel_path)
            runner.logger.info(f'upload {local_file} to {oss_file}')
            io.safe_copy(local_file, oss_file)

    def upload_file(self, runner):
        """Upload the current epoch checkpoint and all extra files.

        Extra files are the patterns in ``runner.file_upload_perepoch``
        (when the runner defines that attribute) plus
        ``self.other_file_list``.
        """
        patterns = list(getattr(runner, 'file_upload_perepoch', []))
        patterns += self.other_file_list
        # Drop duplicate patterns while keeping a deterministic order.
        patterns = list(dict.fromkeys(patterns))

        # The checkpoint is looked up under epoch + 1; presumably the
        # runner's epoch counter has not been incremented yet when this
        # hook fires -- confirm against the checkpoint hook's naming.
        epoch = runner.epoch + 1

        # try to upload ckpt model
        ckpt_fname = self.ckpt_filename_tmpl.format(epoch)
        local_ckpt = os.path.join(self.work_dir, ckpt_fname)
        oss_ckpt = os.path.join(self.oss_work_dir, ckpt_fname)
        if os.path.exists(local_ckpt):
            runner.logger.info(f'upload {local_ckpt} to {oss_ckpt}')
            io.safe_copy(local_ckpt, oss_ckpt)
        else:
            runner.logger.warning(
                f'{local_ckpt} does not exists, skip upload')

        for pattern in patterns:
            self._upload_glob(runner, pattern)

        # NOTE: uploading the whole ``tf_logs`` directory with
        # ``io.copytree`` is currently disabled; add a pattern to
        # ``other_file_list`` to upload such files instead.

    @master_only
    def after_train_iter(self, runner):
        """Upload checkpoint and extra files every ``iter_interval`` iters.

        Does nothing when ``iter_interval`` is None.
        """
        if self.iter_interval is None:
            return
        if self.every_n_inner_iters(runner, self.iter_interval):
            self.upload_file(runner)

    @master_only
    def after_train_epoch(self, runner):
        """Upload checkpoint and extra files every ``interval`` epochs."""
        if self.every_n_epochs(runner, self.interval):
            self.upload_file(runner)

    @master_only
    def after_run(self, runner):
        """Upload the final log files and the exported model files."""
        # upload final log files
        timestamp = runner.timestamp
        log_files = [
            '{}.log'.format(timestamp),
            '{}.log.json'.format(timestamp),
        ]
        for log_file in log_files:
            local_log = os.path.join(self.work_dir, log_file)
            if not os.path.exists(local_log):
                runner.logger.warning(
                    f'{local_log} does not exists, skip upload')
                continue
            oss_log = os.path.join(self.oss_work_dir, log_file)
            runner.logger.info(f'upload {local_log} to {oss_log}')
            io.safe_copy(local_log, oss_log)

        # try to upload exported model; unlike upload_file, the epoch is
        # used as-is here (no + 1).
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(
            runner.epoch)
        # upload all export files whose name contains the export filename
        self._upload_glob(runner, '*{}*'.format(export_ckpt_fname))