mirror of https://github.com/alibaba/EasyCV.git
93 lines
3.8 KiB
Python
93 lines
3.8 KiB
Python
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
from mmcv.runner import Hook
|
||
|
from mmcv.runner.dist_utils import master_only
|
||
|
|
||
|
from easycv.utils import get_root_logger
|
||
|
from .registry import HOOKS
|
||
|
|
||
|
|
||
|
@HOOKS.register_module()
class BestCkptSaverHook(Hook):
    """Save a ``<metric>_best.pth`` checkpoint when a tracked metric is best.

    After every training epoch (and once in ``before_run``) the hook scans
    ``runner.eval_res``. For each tracked metric it picks the best record
    seen so far; if that record's ``runner_epoch`` equals ``runner.epoch``,
    a checkpoint named ``<metric>_best.pth`` is written to ``out_dir`` and
    the file name is added to ``runner.file_upload_perepoch``.

    Args:
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Stored on the instance; not read by this class itself.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        best_metric_name (List(str)): metric names to track, such as
            "neck_top1". A single string is treated as a one-element list.
            Default: None (empty), nothing is saved.
        best_metric_type (List(str)): how "best" is defined for each entry
            of ``best_metric_name``; each item should be "max" or "min".
            If len(best_metric_type) < len(best_metric_name), "max" is
            appended until both lists have the same length.
            Default: None (empty).
        **kwargs: extra arguments, kept unchanged in ``self.args``.
    """

    # Only the documented reducers are supported. Resolving them through a
    # table (instead of eval() on a configuration string) keeps arbitrary
    # expressions in a config from being executed.
    _METRIC_SELECTORS = {'max': max, 'min': min}

    def __init__(self,
                 by_epoch=True,
                 save_optimizer=True,
                 best_metric_name=None,
                 best_metric_type=None,
                 **kwargs):
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        # Filled with runner.work_dir in before_run when left unset.
        self.out_dir = None
        # Always work on private copies: best_metric_type is padded below,
        # and mutating a default argument or the caller's list would leak
        # the padding into other hook instances.
        self.best_metric_name = self._as_list(best_metric_name)
        self.best_metric_type = self._as_list(best_metric_type)

        self.logger = get_root_logger()
        if not self.best_metric_name:
            self.logger.warning(
                'BestCkptSaverHook should assign best_metric_name, otherwise ,no best ckpt should will save'
            )
        if len(self.best_metric_name) != len(self.best_metric_type):
            self.logger.warning(
                f'BestCkptSaverHook should have same length of best_metric_name and best_metric_type ({len(self.best_metric_name)} vs {len(self.best_metric_type)})'
            )
            self.logger.warning(
                'BestCkptSaverHook will use max as default metric type')

        # Pad the types with 'max' so every metric name has a type.
        missing = len(self.best_metric_name) - len(self.best_metric_type)
        if missing > 0:
            self.best_metric_type.extend(['max'] * missing)

        self.args = kwargs

    @staticmethod
    def _as_list(value):
        """Return ``value`` as a new list (None -> [], str -> [str])."""
        if value is None:
            return []
        if isinstance(value, str):
            # list('abc') would split the name into characters.
            return [value]
        return list(value)

    def _selector_for(self, key):
        """Return the builtin (``max`` or ``min``) configured for ``key``.

        Raises:
            ValueError: if the configured type is neither "max" nor "min".
        """
        metric_type = self.best_metric_type[self.best_metric_name.index(key)]
        try:
            return self._METRIC_SELECTORS[metric_type]
        except (KeyError, TypeError):
            raise ValueError(
                f'Unsupported best_metric_type {metric_type!r} for metric '
                f'{key!r}, expected one of {sorted(self._METRIC_SELECTORS)}'
            ) from None

    @master_only
    def before_run(self, runner):
        """Prepare runner/hook state and run one best-checkpoint pass."""
        if not hasattr(runner, 'file_upload_perepoch'):
            runner.file_upload_perepoch = []

        if not self.out_dir:
            self.out_dir = runner.work_dir

        # Handle evaluation results that already exist before training.
        self.after_train_epoch(runner)

    @master_only
    def after_train_epoch(self, runner):
        """Save ``<metric>_best.pth`` for metrics whose best is this epoch."""
        if not self.best_metric_name or not hasattr(runner, 'eval_res'):
            return

        self.logger.info(f'SaveBest metric_name: {self.best_metric_name}')
        for result_list in runner.eval_res.values():
            if not result_list:
                continue
            # The metric names are taken from the first record of the list.
            for key in result_list[0]:
                # 'runner_epoch' is bookkeeping stored next to the metrics,
                # not a metric itself.
                if key == 'runner_epoch' or key not in self.best_metric_name:
                    continue

                select_best = self._selector_for(key)
                best = select_best(result_list, key=lambda rec: rec[key])
                # Only save when the best record was produced in the current
                # epoch; a record without 'runner_epoch' never matches.
                if best.get('runner_epoch') != runner.epoch:
                    continue

                ckpt_name = '%s_best.pth' % (key)
                runner.file_upload_perepoch.append(ckpt_name)
                # De-duplicate while keeping the insertion order stable.
                runner.file_upload_perepoch = list(
                    dict.fromkeys(runner.file_upload_perepoch))
                meta = {'epoch': runner.epoch - 1, **best}
                runner.save_checkpoint(
                    self.out_dir,
                    filename_tmpl=ckpt_name,
                    save_optimizer=self.save_optimizer,
                    meta=meta)

        self.logger.info('End SaveBest metric')
|