mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Save checkpoint again to update best_ckpt of ckpt (#1168)

parent 9d9f2b761e
commit 19aa1eb780
@@ -3,6 +3,7 @@ import hashlib
 import logging
 import os.path as osp
 import pickle
+from collections import deque
 from math import inf
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -242,6 +243,8 @@ class CheckpointHook(Hook):
                 'Find duplicate elements in "published_keys".')
         self.published_keys = published_keys
 
+        self.last_ckpt = None
+
     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
 
@@ -294,6 +297,25 @@ class CheckpointHook(Hook):
                         key_indicator] = runner.message_hub.get_info(
                             best_ckpt_name)
 
+        if self.max_keep_ckpts > 0:
+            keep_ckpt_ids = []
+            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
+                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')
+
+                while len(keep_ckpt_ids) > self.max_keep_ckpts:
+                    step = keep_ckpt_ids.pop(0)
+                    if is_main_process():
+                        path = self.file_backend.join_path(
+                            self.out_dir, self.filename_tmpl.format(step))
+                        if self.file_backend.isfile(path):
+                            self.file_backend.remove(path)
+                        elif self.file_backend.isdir(path):
+                            # checkpoints saved by deepspeed are directories
+                            self.file_backend.rmtree(path)
+
+            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
+                                              self.max_keep_ckpts)
+
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
 
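Note on the hunk above: `before_train` now restores `keep_ckpt_ids` from the message hub, so a resumed run knows which checkpoints are still on disk, deletes any surplus ones, and keeps the bookkeeping in a bounded deque. A standalone sketch of the deque behavior this relies on (plain Python, not mmengine code):

from collections import deque

# Seed the bookkeeping with the ids recovered from a resumed checkpoint.
resumed_ids = [9, 10, 11]
keep_ckpt_ids = deque(resumed_ids, maxlen=3)

# Appending to a full bounded deque silently drops the oldest id;
# the hook deletes the corresponding file before this point.
keep_ckpt_ids.append(12)
print(list(keep_ckpt_ids))  # [10, 11, 12]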
@@ -337,14 +359,13 @@ class CheckpointHook(Hook):
         if self.published_keys is None:
             return
 
-        if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info:
-            last_ckpt = runner.message_hub.get_info('last_ckpt')
-            self._publish_model(runner, last_ckpt)
+        if self.save_last and self.last_ckpt is not None:
+            self._publish_model(runner, self.last_ckpt)
+
         if getattr(self, 'best_ckpt_path', None) is not None:
             self._publish_model(runner, str(self.best_ckpt_path))
         if getattr(self, 'best_ckpt_path_dict', None) is not None:
-            for key, best_ckpt in self.best_ckpt_path_dict.items():
+            for best_ckpt in self.best_ckpt_path_dict.values():
                 self._publish_model(runner, best_ckpt)
 
     @master_only
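With `self.last_ckpt` cached on the hook (initialized to None in `__init__` above), `after_run` no longer round-trips through `runner.message_hub.runtime_info` to find the last checkpoint, and None doubles as the "nothing saved yet" flag. A hypothetical, stripped-down form of the new guard:

class PublishGuardSketch:
    # Not mmengine code; illustrates the attribute-based check only.
    def __init__(self):
        self.last_ckpt = None  # set by _save_checkpoint_with_step

    def after_run(self, publish):
        if self.last_ckpt is not None:  # skip when nothing was saved
            publish(self.last_ckpt)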
@@ -379,16 +400,35 @@ class CheckpointHook(Hook):
                 f'{final_path}.',
                 logger='current')
 
-    def _save_checkpoint(self, runner) -> None:
-        """Save the current checkpoint and delete outdated checkpoint.
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        if self.by_epoch:
-            ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
-        else:
-            ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
+    def _save_checkpoint_with_step(self, runner, step, meta):
+        # remove other checkpoints before save checkpoint to make the
+        # self.keep_ckpt_ids are saved as expected
+        if self.max_keep_ckpts > 0:
+            # _save_checkpoint and _save_best_checkpoint may call this
+            # _save_checkpoint_with_step in one epoch
+            if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step:
+                pass
+            else:
+                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
+                    _step = self.keep_ckpt_ids.popleft()
+                    if is_main_process():
+                        ckpt_path = self.file_backend.join_path(
+                            self.out_dir, self.filename_tmpl.format(_step))
+
+                        if self.file_backend.isfile(ckpt_path):
+                            self.file_backend.remove(ckpt_path)
+                        elif self.file_backend.isdir(ckpt_path):
+                            # checkpoints saved by deepspeed are directories
+                            self.file_backend.rmtree(ckpt_path)
+
+                self.keep_ckpt_ids.append(step)
+                runner.message_hub.update_info('keep_ckpt_ids',
+                                               list(self.keep_ckpt_ids))
+
+        ckpt_filename = self.filename_tmpl.format(step)
+        self.last_ckpt = self.file_backend.join_path(self.out_dir,
+                                                     ckpt_filename)
+        runner.message_hub.update_info('last_ckpt', self.last_ckpt)
 
         runner.save_checkpoint(
             self.out_dir,
@@ -396,6 +436,7 @@ class CheckpointHook(Hook):
             self.file_client_args,
             save_optimizer=self.save_optimizer,
             save_param_scheduler=self.save_param_scheduler,
+            meta=meta,
             by_epoch=self.by_epoch,
             backend_args=self.backend_args,
             **self.args)
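The guard at the top of `_save_checkpoint_with_step` exists because `_save_checkpoint` and `_save_best_checkpoint` may both call it for the same step within one epoch; checking `keep_ckpt_ids[-1] == step` keeps the step from being recorded twice and an extra old checkpoint from being evicted. A rough sketch of that dedup logic in isolation, using a hypothetical helper name:

from collections import deque
from typing import Optional

def record_step(keep_ckpt_ids: deque, step: int, max_keep: int) -> Optional[int]:
    # Hypothetical helper: returns the id of a checkpoint to delete, if any.
    if keep_ckpt_ids and keep_ckpt_ids[-1] == step:
        return None  # same step saved twice in one epoch (regular + best)
    evicted = keep_ckpt_ids.popleft() if len(keep_ckpt_ids) == max_keep else None
    keep_ckpt_ids.append(step)
    return evicted

ids = deque([4, 5])
print(record_step(ids, 6, max_keep=2))  # 4 -> epoch_4.pth would be removed
print(record_step(ids, 6, max_keep=2))  # None -> duplicate save is ignored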
@@ -405,31 +446,24 @@ class CheckpointHook(Hook):
         if not is_main_process():
             return
 
-        runner.message_hub.update_info(
-            'last_ckpt',
-            self.file_backend.join_path(self.out_dir, ckpt_filename))
-
-        # remove other checkpoints
-        if self.max_keep_ckpts > 0:
-            if self.by_epoch:
-                current_ckpt = runner.epoch + 1
-            else:
-                current_ckpt = runner.iter + 1
-            redundant_ckpts = range(
-                current_ckpt - self.max_keep_ckpts * self.interval, 0,
-                -self.interval)
-            for _step in redundant_ckpts:
-                ckpt_path = self.file_backend.join_path(
-                    self.out_dir, self.filename_tmpl.format(_step))
-                if self.file_backend.isfile(ckpt_path):
-                    self.file_backend.remove(ckpt_path)
-                else:
-                    break
-
         save_file = osp.join(runner.work_dir, 'last_checkpoint')
-        filepath = self.file_backend.join_path(self.out_dir, ckpt_filename)
         with open(save_file, 'w') as f:
-            f.write(filepath)
+            f.write(self.last_ckpt)  # type: ignore
+
+    def _save_checkpoint(self, runner) -> None:
+        """Save the current checkpoint and delete outdated checkpoint.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if self.by_epoch:
+            step = runner.epoch + 1
+            meta = dict(epoch=step, iter=runner.iter)
+        else:
+            step = runner.iter + 1
+            meta = dict(epoch=runner.epoch, iter=step)
+
+        self._save_checkpoint_with_step(runner, step, meta=meta)
 
     def _save_best_checkpoint(self, runner, metrics) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
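The block removed above is the motivation for the rewrite: the old cleanup reconstructed deletable checkpoint ids arithmetically and stopped at the first missing file, which assumes every save lands exactly on the interval grid. Off-grid saves, such as the re-save triggered by a new best checkpoint, break that assumption, so surplus files could leak. An illustrative comparison of the two strategies:

from collections import deque

interval, max_keep = 2, 2
current = 6

# Old strategy: infer deletable steps from interval arithmetic.
old_deletions = list(range(current - max_keep * interval, 0, -interval))
print(old_deletions)  # [2] -- the off-grid step 5 is never considered

# New strategy: remember which steps were actually saved.
saved_steps = [2, 4, 5, 6]  # step 5 came from an off-interval best-ckpt save
keep, deletions = deque(), []
for step in saved_steps:
    if len(keep) == max_keep:
        deletions.append(keep.popleft())
    keep.append(step)
print(deletions, list(keep))  # [2, 4] [5, 6]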
@@ -448,10 +482,13 @@ class CheckpointHook(Hook):
             ckpt_filename = self.filename_tmpl.format(runner.iter)
             cur_type, cur_time = 'iter', runner.iter
 
+        meta = dict(epoch=runner.epoch, iter=runner.iter)
+
         # handle auto in self.key_indicators and self.rules before the loop
         if 'auto' in self.key_indicators:
             self._init_rule(self.rules, [list(metrics.keys())[0]])
 
+        best_ckpt_updated = False
         # save best logic
         # get score from messagehub
         for key_indicator, rule in zip(self.key_indicators, self.rules):
@@ -475,13 +512,18 @@ class CheckpointHook(Hook):
                     key_score, best_score):
                 continue
 
+            best_ckpt_updated = True
+
             best_score = key_score
             runner.message_hub.update_info(best_score_key, best_score)
 
-            if best_ckpt_path and \
-                    self.file_backend.isfile(best_ckpt_path) and \
-                    is_main_process():
-                self.file_backend.remove(best_ckpt_path)
+            if best_ckpt_path and is_main_process():
+                if self.file_backend.isfile(best_ckpt_path):
+                    self.file_backend.remove(best_ckpt_path)
+                else:
+                    # checkpoints saved by deepspeed are directories
+                    self.file_backend.rmtree(best_ckpt_path)
+
                 runner.logger.info(
                     f'The previous best checkpoint {best_ckpt_path} '
                     'is removed')
@@ -507,12 +549,21 @@ class CheckpointHook(Hook):
                 file_client_args=self.file_client_args,
                 save_optimizer=False,
                 save_param_scheduler=False,
+                meta=meta,
                 by_epoch=False,
                 backend_args=self.backend_args)
             runner.logger.info(
                 f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                 f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
 
+        # save checkpoint again to update the best_score and best_ckpt stored
+        # in message_hub, because the checkpoint saved in `after_train_epoch`
+        # or `after_train_iter` only keeps the previous best checkpoint, not
+        # the current one, so the current best checkpoint could not be
+        # removed when resuming training.
+        if best_ckpt_updated and self.last_ckpt is not None:
+            self._save_checkpoint_with_step(runner, cur_time, meta)
+
     def _init_rule(self, rules, key_indicators) -> None:
         """Initialize rule, key_indicator, comparison_func, and best score. If
         key_indicator is a list of string and rule is a string, all metric in
@@ -2090,7 +2090,7 @@ class Runner:
         file_client_args: Optional[dict] = None,
         save_optimizer: bool = True,
         save_param_scheduler: bool = True,
-        meta: dict = None,
+        meta: Optional[dict] = None,
         by_epoch: bool = True,
         backend_args: Optional[dict] = None,
     ):
@@ -2112,8 +2112,8 @@ class Runner:
                 to the checkpoint. Defaults to True.
             meta (dict, optional): The meta information to be saved in the
                 checkpoint. Defaults to None.
-            by_epoch (bool): Whether the scheduled momentum is updated by
-                epochs. Defaults to True.
+            by_epoch (bool): Decide the number of epochs or iterations saved
+                in the checkpoint. Defaults to True.
             backend_args (dict, optional): Arguments to instantiate the
                 prefix of uri corresponding backend. Defaults to None.
                 New in v0.2.0.
@@ -2129,9 +2129,11 @@ class Runner:
             # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
             # called by `after_train_epoch`` method of `CheckpointHook` so
             # `epoch` should be `self.epoch + 1`
-            meta.update(epoch=self.epoch + 1, iter=self.iter)
+            meta.setdefault('epoch', self.epoch + 1)
+            meta.setdefault('iter', self.iter)
         else:
-            meta.update(epoch=self.epoch, iter=self.iter + 1)
+            meta.setdefault('epoch', self.epoch)
+            meta.setdefault('iter', self.iter + 1)
 
         if file_client_args is not None:
             warnings.warn(
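Switching from `meta.update` to `meta.setdefault` lets the epoch/iter pair passed in by `CheckpointHook._save_checkpoint_with_step` win, while callers that pass no meta still get the runner's defaults; `update` would have clobbered the hook's values on the best-checkpoint re-save. A two-assert illustration of the difference:

# setdefault keeps a caller-provided value; update overwrites it.
meta = {'epoch': 5, 'iter': 20}  # as passed by CheckpointHook
meta.setdefault('epoch', 6)      # runner default is ignored
assert meta['epoch'] == 5

meta2 = {'epoch': 5, 'iter': 20}
meta2.update(epoch=6, iter=21)   # old behavior: hook's meta clobbered
assert meta2['epoch'] == 6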
@@ -433,12 +433,13 @@ class TestCheckpointHook(RunnerTestCase):
         setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
         checkpoint_cfg = dict(
             type='CheckpointHook',
-            interval=2,
+            interval=1,
             by_epoch=training_type == 'epoch')
         common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
 
         # Test interval in epoch based training
         cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.interval = 2
         runner = self.build_runner(cfg)
         runner.train()
 
@@ -502,13 +503,11 @@ class TestCheckpointHook(RunnerTestCase):
 
         self.clear_work_dir()
 
-        # Test max_keep_ckpts
+        # Test max_keep_ckpts=1
         cfg = copy.deepcopy(common_cfg)
-        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.max_keep_ckpts = 1
         runner = self.build_runner(cfg)
         runner.train()
-        print(os.listdir(cfg.work_dir))
         self.assertTrue(
             osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
 
@@ -518,6 +517,45 @@ class TestCheckpointHook(RunnerTestCase):
 
         self.clear_work_dir()
 
+        # Test max_keep_ckpts=3
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 3
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+
+        for i in range(9):
+            self.assertFalse(
+                osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
+
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
+                         [9, 10, 11])
+
+        # Test max_keep_ckpts when resuming training
+        cfg = copy.deepcopy(common_cfg)
+        setattr(cfg.train_cfg, f'max_{training_type}s', 12)
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 2
+        cfg.load_from = osp.join(cfg.work_dir, f'{training_type}_11.pth')
+        cfg.resume = True
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertFalse(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
+        self.assertFalse(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_12.pth')))
+
+        self.clear_work_dir()
+
         # Test filename_tmpl
         cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
@@ -529,18 +567,61 @@ class TestCheckpointHook(RunnerTestCase):
 
         # Test save_best
         cfg = copy.deepcopy(common_cfg)
-        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.save_best = 'test/acc'
         cfg.val_evaluator = dict(type='TriangleMetric', length=11)
         cfg.train_cfg.val_interval = 1
         runner = self.build_runner(cfg)
         runner.train()
-        best_ckpt = osp.join(cfg.work_dir,
-                             f'best_test_acc_{training_type}_5.pth')
-        self.assertTrue(osp.isfile(best_ckpt))
+        best_ckpt_path = osp.join(cfg.work_dir,
+                                  f'best_test_acc_{training_type}_5.pth')
+        best_ckpt = torch.load(best_ckpt_path)
+
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        self.assertEqual(best_ckpt_path,
+                         ckpt['message_hub']['runtime_info']['best_ckpt'])
+
+        if training_type == 'epoch':
+            self.assertEqual(ckpt['meta']['epoch'], 5)
+            self.assertEqual(ckpt['meta']['iter'], 20)
+            self.assertEqual(best_ckpt['meta']['epoch'], 5)
+            self.assertEqual(best_ckpt['meta']['iter'], 20)
+        else:
+            self.assertEqual(ckpt['meta']['epoch'], 0)
+            self.assertEqual(ckpt['meta']['iter'], 5)
+            self.assertEqual(best_ckpt['meta']['epoch'], 0)
+            self.assertEqual(best_ckpt['meta']['iter'], 5)
 
         self.clear_work_dir()
 
+        # Test save_best with interval=2
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.save_best = 'test/acc'
+        cfg.default_hooks.checkpoint.interval = 2
+        cfg.val_evaluator = dict(type='TriangleMetric', length=11)
+        cfg.train_cfg.val_interval = 1
+        runner = self.build_runner(cfg)
+        runner.train()
+        best_ckpt_path = osp.join(cfg.work_dir,
+                                  f'best_test_acc_{training_type}_5.pth')
+        best_ckpt = torch.load(best_ckpt_path)
+
+        # if the current ckpt is the best, the interval will be ignored and
+        # the ckpt will also be saved
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        self.assertEqual(best_ckpt_path,
+                         ckpt['message_hub']['runtime_info']['best_ckpt'])
+
+        if training_type == 'epoch':
+            self.assertEqual(ckpt['meta']['epoch'], 5)
+            self.assertEqual(ckpt['meta']['iter'], 20)
+            self.assertEqual(best_ckpt['meta']['epoch'], 5)
+            self.assertEqual(best_ckpt['meta']['iter'], 20)
+        else:
+            self.assertEqual(ckpt['meta']['epoch'], 0)
+            self.assertEqual(ckpt['meta']['iter'], 5)
+            self.assertEqual(best_ckpt['meta']['epoch'], 0)
+            self.assertEqual(best_ckpt['meta']['iter'], 5)
 
         # test save published keys
         cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']