diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index c9aa1a35..2ccd31d9 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+import re
 import warnings
 from collections import OrderedDict
 from math import inf
@@ -299,8 +300,8 @@ class CheckpointHook(Hook):
     def _get_metric_score(self, metrics, key_indicator):
         eval_res = OrderedDict()
         if metrics is not None:
-            eval_res.update(metrics)
-
+            for key, value in metrics.items():
+                eval_res[key.split('/')[-1]] = value
         if len(eval_res) == 0:
             warnings.warn(
                 'Since `eval_res` is an empty dict, the behavior to save '
@@ -417,6 +418,7 @@ class CheckpointHook(Hook):
                     'is removed')
 
         best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
+        best_ckpt_name = re.sub(r'[^\w\.]', '_', best_ckpt_name)
         if len(self.key_indicators) == 1:
             self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
                 self.out_dir, best_ckpt_name)
diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py
index 8fbb1a56..bfe8bd65 100644
--- a/tests/test_hooks/test_checkpoint_hook.py
+++ b/tests/test_hooks/test_checkpoint_hook.py
@@ -1,56 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
+import copy
 import os.path as osp
-from unittest.mock import Mock
 
-import pytest
 import torch
-import torch.nn as nn
-from torch.utils.data import Dataset
+from parameterized import parameterized
 
 from mmengine.evaluator import BaseMetric
 from mmengine.fileio import FileClient, LocalBackend
 from mmengine.hooks import CheckpointHook
 from mmengine.logging import MessageHub
-from mmengine.model import BaseModel
-from mmengine.optim import OptimWrapper
-from mmengine.runner import Runner
-
-
-class ToyModel(BaseModel):
-
-    def __init__(self):
-        super().__init__()
-        self.linear = nn.Linear(2, 1)
-
-    def forward(self, inputs, data_sample, mode='tensor'):
-        labels = torch.stack(data_sample)
-        inputs = torch.stack(inputs)
-        outputs = self.linear(inputs)
-        if mode == 'tensor':
-            return outputs
-        elif mode == 'loss':
-            loss = (labels - outputs).sum()
-            outputs = dict(loss=loss)
-            return outputs
-        else:
-            return outputs
-
-
-class DummyDataset(Dataset):
-    METAINFO = dict()  # type: ignore
-    data = torch.randn(12, 2)
-    label = torch.ones(12)
-
-    @property
-    def metainfo(self):
-        return self.METAINFO
-
-    def __len__(self):
-        return self.data.size(0)
-
-    def __getitem__(self, index):
-        return dict(inputs=self.data[index], data_sample=self.label[index])
+from mmengine.registry import METRICS
+from mmengine.testing import RunnerTestCase
 
 
 class TriangleMetric(BaseMetric):
@@ -72,428 +32,435 @@ class TriangleMetric(BaseMetric):
         return dict(acc=acc)
 
 
-class TestCheckpointHook:
+class TestCheckpointHook(RunnerTestCase):
 
-    def test_init(self, tmp_path):
+    def setUp(self):
+        super().setUp()
+        METRICS.register_module(module=TriangleMetric, force=True)
+
+    def tearDown(self):
+        METRICS.module_dict.clear()
+        return super().tearDown()
+
+    def test_init(self):
         # Test file_client_args and backend_args
-        with pytest.warns(
+        with self.assertWarnsRegex(
                 DeprecationWarning,
-                match='"file_client_args" will be deprecated in future'):
+                '"file_client_args" will be deprecated in future'):
             CheckpointHook(file_client_args={'backend': 'disk'})
 
-        with pytest.raises(
+        with self.assertRaisesRegex(
                 ValueError,
-                match='"file_client_args" and "backend_args" cannot be set '
+                '"file_client_args" and "backend_args" cannot be set '
                 'at the same time'):
             CheckpointHook(
                 file_client_args={'backend': 'disk'},
                 backend_args={'backend': 'local'})
 
-    def test_before_train(self, tmp_path):
-        runner = Mock()
-        work_dir = str(tmp_path)
-        runner.work_dir = work_dir
+        # Test save_best
+        CheckpointHook(save_best='acc')
+        CheckpointHook(save_best=['acc'])
+
+        with self.assertRaisesRegex(AssertionError, '"save_best" should be'):
+            CheckpointHook(save_best=dict(acc='acc'))
+
+        # error when 'auto' in `save_best` list
+        with self.assertRaisesRegex(AssertionError, 'Only support one'):
+            CheckpointHook(interval=2, save_best=['auto', 'acc'])
+
+        # Test rules
+        CheckpointHook(save_best=['acc', 'mAcc'], rule='greater')
+
+        with self.assertRaisesRegex(AssertionError, '"rule" should be a str'):
+            CheckpointHook(save_best=['acc'], rule=1)
+
+        with self.assertRaisesRegex(AssertionError,
+                                    'Number of "rule" must be'):
+            CheckpointHook(save_best=['acc'], rule=['greater', 'loss'])
+
+        # Test greater_keys
+        hook = CheckpointHook(greater_keys='acc')
+        self.assertEqual(hook.greater_keys, ('acc', ))
+
+        hook = CheckpointHook(greater_keys=['acc'])
+        self.assertEqual(hook.greater_keys, ['acc'])
+
+        hook = CheckpointHook(
+            interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
+        self.assertEqual(hook.greater_keys, ('acc', 'mIoU'))
+        self.assertEqual(hook.rule, ('greater', 'greater'))
+
+        # Test less_keys
+        hook = CheckpointHook(less_keys='loss_cls')
+        self.assertEqual(hook.less_keys, ('loss_cls', ))
+
+        hook = CheckpointHook(less_keys=['loss_cls'])
+        self.assertEqual(hook.less_keys, ['loss_cls'])
+
+    def test_before_train(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
 
         # file_client_args is None
         checkpoint_hook = CheckpointHook()
         checkpoint_hook.before_train(runner)
-        assert isinstance(checkpoint_hook.file_client, FileClient)
-        assert isinstance(checkpoint_hook.file_backend, LocalBackend)
+        self.assertIsInstance(checkpoint_hook.file_client, FileClient)
+        self.assertIsInstance(checkpoint_hook.file_backend, LocalBackend)
 
         # file_client_args is not None
         checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'})
        checkpoint_hook.before_train(runner)
-        assert isinstance(checkpoint_hook.file_client, FileClient)
+        self.assertIsInstance(checkpoint_hook.file_client, FileClient)
         # file_backend is the alias of file_client
-        assert checkpoint_hook.file_backend is checkpoint_hook.file_client
+        self.assertIs(checkpoint_hook.file_backend,
+                      checkpoint_hook.file_client)
 
         # the out_dir of the checkpoint hook is None
         checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
         checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.out_dir == runner.work_dir
+        self.assertEqual(checkpoint_hook.out_dir, runner.work_dir)
 
         # the out_dir of the checkpoint hook is not None
         checkpoint_hook = CheckpointHook(
             interval=1, by_epoch=True, out_dir='test_dir')
         checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.out_dir == osp.join(
-            'test_dir', osp.join(osp.basename(work_dir)))
+        self.assertEqual(
+            checkpoint_hook.out_dir,
+            osp.join('test_dir', osp.basename(cfg.work_dir)))
 
-        runner.message_hub = MessageHub.get_instance('test_before_train')
-        # no 'best_ckpt_path' in runtime_info
+        # If `save_best` is a list of strings, the path to save the best
+        # checkpoint will be defined in attribute `best_ckpt_path_dict`.
         checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
         checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None)
-        assert not hasattr(checkpoint_hook, 'best_ckpt_path')
+        self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
+                         dict(acc=None, mIoU=None))
+        self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path'))
 
-        # only one 'best_ckpt_path' in runtime_info
+        # Resume 'best_ckpt_path' from message_hub
         runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
         checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.best_ckpt_path_dict == dict(
-            acc='best_acc', mIoU=None)
+        self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
+                         dict(acc='best_acc', mIoU=None))
 
-        # no 'best_ckpt_path' in runtime_info
+        # If `save_best` is a string, the path to save best ckpt will be
+        # defined in attribute `best_ckpt_path`
         checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
         checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.best_ckpt_path is None
-        assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict')
+        self.assertIsNone(checkpoint_hook.best_ckpt_path)
+        self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path_dict'))
 
-        # 'best_ckpt_path' in runtime_info
+        # Resume `best_ckpt` path from message_hub
         runner.message_hub.update_info('best_ckpt', 'best_ckpt')
         checkpoint_hook.before_train(runner)
-        assert checkpoint_hook.best_ckpt_path == 'best_ckpt'
+        self.assertEqual(checkpoint_hook.best_ckpt_path, 'best_ckpt')
 
-    def test_after_val_epoch(self, tmp_path):
-        runner = Mock()
-        runner.work_dir = tmp_path
-        runner.epoch = 9
-        runner.model = Mock()
-        runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
-
-        with pytest.raises(ValueError):
-            # key_indicator must be valid when rule_map is None
-            CheckpointHook(interval=2, by_epoch=True, save_best='unsupport')
-
-        with pytest.raises(KeyError):
-            # rule must be in keys of rule_map
-            CheckpointHook(
-                interval=2, by_epoch=True, save_best='auto', rule='unsupport')
-
-        # if eval_res is an empty dict, print a warning information
-        with pytest.warns(UserWarning) as record_warnings:
-            eval_hook = CheckpointHook(
-                interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None, None)
-        # Since there will be many warnings thrown, we just need to check
-        # if the expected exceptions are thrown
-        expected_message = (
-            'Since `eval_res` is an empty dict, the behavior to '
-            'save the best checkpoint will be skipped in this '
-            'evaluation.')
-        for warning in record_warnings:
-            if str(warning.message) == expected_message:
-                break
-        else:
-            assert False
-
-        # test error when number of rules and metrics are not same
-        with pytest.raises(AssertionError) as assert_error:
-            CheckpointHook(
-                interval=1,
-                save_best=['mIoU', 'acc'],
-                rule=['greater', 'greater', 'less'],
-                by_epoch=True)
-        error_message = ('Number of "rule" must be 1 or the same as number of '
-                         '"save_best", but got 3.')
-        assert error_message in str(assert_error.value)
+    def test_after_val_epoch(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        runner.train_loop._epoch = 9
 
         # if save_best is None,no best_ckpt meta should be stored
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, None)
-        assert 'best_score' not in runner.message_hub.runtime_info
-        assert 'best_ckpt' not in runner.message_hub.runtime_info
+        ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
+        ckpt_hook.before_train(runner)
+        ckpt_hook.after_val_epoch(runner, None)
+        self.assertNotIn('best_score', runner.message_hub.runtime_info)
+        self.assertNotIn('best_ckpt', runner.message_hub.runtime_info)
 
         # when `save_best` is set to `auto`, first metric will be used.
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
+        ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
+        ckpt_hook.before_train(runner)
+        ckpt_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_epoch_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 0.5
-        assert 'best_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
+        best_ckpt_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_ckpt_name)
+        self.assertEqual(ckpt_hook.key_indicators, ['acc'])
+        self.assertEqual(ckpt_hook.rules, ['greater'])
+        self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
 
         # # when `save_best` is set to `acc`, it should update greater value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
-        eval_hook.before_train(runner)
+        ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
+        ckpt_hook.before_train(runner)
         metrics['acc'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 0.8
+        ckpt_hook.after_val_epoch(runner, metrics)
+        self.assertEqual(runner.message_hub.get_info('best_score'), 0.8)
 
         # # when `save_best` is set to `loss`, it should update less value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
-        eval_hook.before_train(runner)
+        ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
+        ckpt_hook.before_train(runner)
         metrics['loss'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        ckpt_hook.after_val_epoch(runner, metrics)
         metrics['loss'] = 0.5
-        eval_hook.after_val_epoch(runner, metrics)
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 0.5
+        ckpt_hook.after_val_epoch(runner, metrics)
+        self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
 
         # when `rule` is set to `less`,then it should update less value
         # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        ckpt_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='acc', rule='less')
-        eval_hook.before_train(runner)
+        ckpt_hook.before_train(runner)
         metrics['acc'] = 0.3
-        eval_hook.after_val_epoch(runner, metrics)
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 0.3
+        ckpt_hook.after_val_epoch(runner, metrics)
+        self.assertEqual(runner.message_hub.get_info('best_score'), 0.3)
 
         # # when `rule` is set to `greater`,then it should update greater value
         # # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        ckpt_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='loss', rule='greater')
-        eval_hook.before_train(runner)
+        ckpt_hook.before_train(runner)
         metrics['loss'] = 1.0
-        eval_hook.after_val_epoch(runner, metrics)
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 1.0
+        ckpt_hook.after_val_epoch(runner, metrics)
+        self.assertEqual(runner.message_hub.get_info('best_score'), 1.0)
 
         # test multi `save_best` with one rule
-        eval_hook = CheckpointHook(
+        ckpt_hook = CheckpointHook(
             interval=2, save_best=['acc', 'mIoU'], rule='greater')
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        self.assertEqual(ckpt_hook.key_indicators, ['acc', 'mIoU'])
+        self.assertEqual(ckpt_hook.rules, ['greater', 'greater'])
 
         # test multi `save_best` with multi rules
-        eval_hook = CheckpointHook(
+        ckpt_hook = CheckpointHook(
             interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
-        assert eval_hook.key_indicators == ['FID', 'IS']
-        assert eval_hook.rules == ['less', 'greater']
+        self.assertEqual(ckpt_hook.key_indicators, ['FID', 'IS'])
+        self.assertEqual(ckpt_hook.rules, ['less', 'greater'])
 
         # test multi `save_best` with default rule
-        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        ckpt_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        self.assertEqual(ckpt_hook.key_indicators, ['acc', 'mIoU'])
+        self.assertEqual(ckpt_hook.rules, ['greater', 'greater'])
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best')
-        eval_hook.before_train(runner)
+        ckpt_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        ckpt_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_epoch_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_epoch_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
-        assert 'best_score_acc' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score_acc') == 0.5
-        assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score_mIoU') == 0.6
-        assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
-        assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
+        best_mIoU_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_mIoU_name)
+        self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
+        self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
 
         # test behavior when by_epoch is False
-        runner = Mock()
-        runner.work_dir = tmp_path
-        runner.iter = 9
-        runner.model = Mock()
-        runner.message_hub = MessageHub.get_instance(
-            'test_after_val_epoch_by_epoch_is_false')
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        runner = self.build_runner(cfg)
+        runner.train_loop._iter = 9
 
         # check best ckpt name and best score
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(
+        ckpt_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best='acc', rule='greater')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        ckpt_hook.before_train(runner)
+        ckpt_hook.after_val_epoch(runner, metrics)
+        self.assertEqual(ckpt_hook.key_indicators, ['acc'])
+        self.assertEqual(ckpt_hook.rules, ['greater'])
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert 'best_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 0.5
+        best_ckpt_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_ckpt_name)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
+        self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
 
         # check best score updating
         metrics['acc'] = 0.666
-        eval_hook.after_val_epoch(runner, metrics)
+        ckpt_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert 'best_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
-        assert 'best_score' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score') == 0.666
-        # error when 'auto' in `save_best` list
-        with pytest.raises(AssertionError):
-            CheckpointHook(interval=2, save_best=['auto', 'acc'])
-        # error when one `save_best` with multi `rule`
-        with pytest.raises(AssertionError):
-            CheckpointHook(
-                interval=2, save_best='acc', rule=['greater', 'less'])
+        best_ckpt_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_ckpt_name)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
+        self.assertEqual(runner.message_hub.get_info('best_score'), 0.666)
 
         # check best checkpoint name with `by_epoch` is False
-        eval_hook = CheckpointHook(
+        ckpt_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
-        runner.message_hub = MessageHub.get_instance(
-            'test_after_val_epoch_save_multi_best_by_epoch_is_false')
-        eval_hook.before_train(runner)
+        ckpt_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        ckpt_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_iter_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_iter_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
-        assert 'best_score_acc' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score_acc') == 0.5
-        assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_score_mIoU') == 0.6
-        assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
-        assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
+        best_mIoU_path = ckpt_hook.file_client.join_path(
+            ckpt_hook.out_dir, best_mIoU_name)
-        # after_val_epoch should not save last_checkpoint.
-        assert not osp.isfile(osp.join(runner.work_dir, 'last_checkpoint'))
+        self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
+        self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
+        self.assertEqual(
+            runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
 
-    def test_after_train_epoch(self, tmp_path):
-        runner = Mock()
-        work_dir = str(tmp_path)
-        runner.work_dir = tmp_path
-        runner.epoch = 9
-        runner.model = Mock()
-        runner.message_hub = MessageHub.get_instance('test_after_train_epoch')
+        # after_val_epoch should not save last_checkpoint
+        self.assertFalse(
+            osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))
+
+    def test_after_train_epoch(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        runner.train_loop._epoch = 9
+        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
-        assert (runner.epoch + 1) % 2 == 0
-        assert 'last_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('last_ckpt') == \
-            osp.join(work_dir, 'epoch_10.pth')
-        last_ckpt_path = osp.join(work_dir, 'last_checkpoint')
-        assert osp.isfile(last_ckpt_path)
+        self.assertEqual((runner.epoch + 1) % 2, 0)
+        self.assertEqual(
+            runner.message_hub.get_info('last_ckpt'),
+            osp.join(cfg.work_dir, 'epoch_10.pth'))
+
+        last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint')
+        self.assertTrue(osp.isfile(last_ckpt_path))
+
         with open(last_ckpt_path) as f:
             filepath = f.read()
-            assert filepath == osp.join(work_dir, 'epoch_10.pth')
+            self.assertEqual(filepath, osp.join(cfg.work_dir, 'epoch_10.pth'))
 
         # epoch can not be evenly divided by 2
-        runner.epoch = 10
+        runner.train_loop._epoch = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert 'last_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('last_ckpt') == \
-            osp.join(work_dir, 'epoch_10.pth')
+        self.assertEqual(
+            runner.message_hub.get_info('last_ckpt'),
+            osp.join(cfg.work_dir, 'epoch_10.pth'))
+        runner.message_hub.runtime_info.clear()
 
         # by epoch is False
-        runner.epoch = 9
-        runner.message_hub = MessageHub.get_instance('test_after_train_epoch1')
+        runner.train_loop._epoch = 9
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
-        assert 'last_ckpt' not in runner.message_hub.runtime_info
-
-        # # max_keep_ckpts > 0
-        runner.work_dir = work_dir
-        os.system(f'touch {work_dir}/epoch_8.pth')
-        checkpoint_hook = CheckpointHook(
-            interval=2, by_epoch=True, max_keep_ckpts=1)
-        checkpoint_hook.before_train(runner)
-        checkpoint_hook.after_train_epoch(runner)
-        assert (runner.epoch + 1) % 2 == 0
-        assert not os.path.exists(f'{work_dir}/epoch_8.pth')
-
-        # save_checkpoint of runner should be called with expected arguments
-        runner = Mock()
-        work_dir = str(tmp_path)
-        runner.work_dir = tmp_path
-        runner.epoch = 1
-        runner.message_hub = MessageHub.get_instance('test_after_train_epoch2')
-
-        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
-        checkpoint_hook.before_train(runner)
-        checkpoint_hook.after_train_epoch(runner)
-
-        runner.save_checkpoint.assert_called_once_with(
-            runner.work_dir,
-            'epoch_2.pth',
-            None,
-            backend_args=None,
-            by_epoch=True,
-            save_optimizer=True,
-            save_param_scheduler=True)
-
-    def test_after_train_iter(self, tmp_path):
-        work_dir = str(tmp_path)
-        runner = Mock()
-        runner.work_dir = str(work_dir)
-        runner.iter = 9
-        batch_idx = 9
-        runner.model = Mock()
-        runner.message_hub = MessageHub.get_instance('test_after_train_iter')
+        self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
+        runner.message_hub.runtime_info.clear()
 
+    def test_after_train_iter(self):
         # by epoch is True
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        runner = self.build_runner(cfg)
+        runner.train_loop._iter = 9
+        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_train(runner)
-        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
-        assert 'last_ckpt' not in runner.message_hub.runtime_info
+        checkpoint_hook.after_train_iter(runner, batch_idx=9)
+        self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
 
         # by epoch is False
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_train(runner)
-        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
-        assert (runner.iter + 1) % 2 == 0
-        assert 'last_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('last_ckpt') == \
-            osp.join(work_dir, 'iter_10.pth')
+        checkpoint_hook.after_train_iter(runner, batch_idx=9)
+        self.assertIn('last_ckpt', runner.message_hub.runtime_info)
+        self.assertEqual(
+            runner.message_hub.get_info('last_ckpt'),
+            osp.join(cfg.work_dir, 'iter_10.pth'))
 
         # epoch can not be evenly divided by 2
-        runner.iter = 10
+        runner.train_loop._iter = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert 'last_ckpt' in runner.message_hub.runtime_info and \
-            runner.message_hub.get_info('last_ckpt') == \
-            osp.join(work_dir, 'iter_10.pth')
+        self.assertEqual(
+            runner.message_hub.get_info('last_ckpt'),
+            osp.join(cfg.work_dir, 'iter_10.pth'))
 
-        # max_keep_ckpts > 0
-        runner.iter = 9
-        runner.work_dir = work_dir
-        os.system(f'touch {osp.join(work_dir, "iter_8.pth")}')
-        checkpoint_hook = CheckpointHook(
-            interval=2, by_epoch=False, max_keep_ckpts=1)
-        checkpoint_hook.before_train(runner)
-        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
-        assert not os.path.exists(f'{work_dir}/iter_8.pth')
-
-    def test_with_runner(self, tmp_path):
-        max_epoch = 10
-        work_dir = osp.join(str(tmp_path), 'runner_test')
-        tmpl = '{}.pth'
-        save_interval = 2
+    @parameterized.expand([['iter'], ['epoch']])
+    def test_with_runner(self, training_type):
+        # Test checkpoints are saved at the expected interval
+        save_interval = 2
+        cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg'))
+        setattr(cfg.train_cfg, f'max_{training_type}s', 11)
         checkpoint_cfg = dict(
             type='CheckpointHook',
-            interval=save_interval,
-            filename_tmpl=tmpl,
-            by_epoch=True)
-        runner = Runner(
-            model=ToyModel(),
-            work_dir=work_dir,
-            train_dataloader=dict(
-                dataset=DummyDataset(),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            val_dataloader=dict(
-                dataset=DummyDataset(),
-                sampler=dict(type='DefaultSampler', shuffle=False),
-                batch_size=3,
-                num_workers=0),
-            val_evaluator=dict(type=TriangleMetric, length=max_epoch),
-            optim_wrapper=OptimWrapper(
-                torch.optim.Adam(ToyModel().parameters())),
-            train_cfg=dict(
-                by_epoch=True, max_epochs=max_epoch, val_interval=1),
-            val_cfg=dict(),
-            default_hooks=dict(checkpoint=checkpoint_cfg))
+            interval=save_interval,
+            by_epoch=training_type == 'epoch')
+        cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
+        runner = self.build_runner(cfg)
         runner.train()
-        for epoch in range(max_epoch):
-            if epoch % save_interval != 0 or epoch == 0:
-                continue
-            path = osp.join(work_dir, tmpl.format(epoch))
-            assert osp.isfile(path=path)
+
+        for i in range(1, 11):
+            if i % 2 == 0:
+                self.assertTrue(
+                    osp.isfile(
+                        osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
+            else:
+                self.assertFalse(
+                    osp.isfile(
+                        osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
+
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+
+        # Test save_optimizer=False
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertIn('optimizer', ckpt)
+        cfg.default_hooks.checkpoint.save_optimizer = False
+        runner = self.build_runner(cfg)
+        runner.train()
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertNotIn('optimizer', ckpt)
+
+        # Test save_param_scheduler=False
+        cfg.param_scheduler = [
+            dict(
+                type='LinearLR',
+                start_factor=0.1,
+                begin=0,
+                end=500,
+                by_epoch=training_type == 'epoch')
+        ]
+        runner = self.build_runner(cfg)
+        runner.train()
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertIn('param_schedulers', ckpt)
+
+        cfg.default_hooks.checkpoint.save_param_scheduler = False
+        runner = self.build_runner(cfg)
+        runner.train()
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertNotIn('param_schedulers', ckpt)
+
+        # Test out_dir
+        out_dir = osp.join(self.temp_dir.name, 'out_dir')
+        cfg.default_hooks.checkpoint.out_dir = out_dir
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(
+            osp.isfile(
+                osp.join(out_dir, osp.basename(cfg.work_dir),
+                         f'{training_type}_11.pth')))
+
+        # Test max_keep_ckpts
+        del cfg.default_hooks.checkpoint.out_dir
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 1
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+
+        for i in range(10):
+            self.assertFalse(
+                osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
+
+        # Test filename_tmpl
+        cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth')))
+
+        # Test save_best
+        cfg.default_hooks.checkpoint.save_best = 'acc'
+        cfg.val_evaluator = dict(type='TriangleMetric', length=11)
+        cfg.train_cfg.val_interval = 1
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, 'best_acc_test_5.pth')))
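
Reviewer note: the loop added to `_get_metric_score` normalizes metric keys that
carry an evaluator prefix before the best score is looked up. A minimal
standalone sketch of that behavior (the prefixed keys below are hypothetical
examples, not taken from this patch):

    from collections import OrderedDict

    metrics = {'val/acc': 0.9, 'coco/bbox_mAP': 0.4, 'loss': 0.2}

    eval_res = OrderedDict()
    for key, value in metrics.items():
        # 'val/acc' -> 'acc'; keys without a '/' are kept unchanged. `split`
        # is used rather than `partition` because `partition('/')[-1]`
        # returns '' when the separator is absent.
        eval_res[key.split('/')[-1]] = value

    assert list(eval_res) == ['acc', 'bbox_mAP', 'loss']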
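
Reviewer note: the `re.sub` call added to `_save_best_checkpoint` keeps path
separators (and other non-word characters, except the dot) out of the best
checkpoint's file name, so the `.pth` suffix and the name asserted at the end
of `test_with_runner` both survive. A sketch, again with a hypothetical
'coco/bbox_mAP' indicator:

    import re

    key_indicator = 'coco/bbox_mAP'
    ckpt_filename = 'epoch_10.pth'

    best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
    # Replace every character that is neither a word character nor a dot,
    # so '/' cannot introduce an unintended sub-directory.
    best_ckpt_name = re.sub(r'[^\w\.]', '_', best_ckpt_name)

    assert best_ckpt_name == 'best_coco_bbox_mAP_epoch_10.pth'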