diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index bae8bd65..23f4a8ab 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -85,7 +85,7 @@ class CheckpointHook(Hook): accordingly. backend_args (dict, optional): Arguments to instantiate the prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. + `New in version 0.2.0.` published_keys (str, List[str], optional): If ``save_last`` is ``True`` or ``save_best`` is not ``None``, it will automatically publish model with keys in the list after training. diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index 6a80dcdd..e6469bb3 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -429,34 +429,37 @@ class TestCheckpointHook(RunnerTestCase): @parameterized.expand([['iter'], ['epoch']]) def test_with_runner(self, training_type): - # Test interval in epoch based training - save_iterval = 2 - cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg')) - setattr(cfg.train_cfg, f'max_{training_type}s', 11) + common_cfg = getattr(self, f'{training_type}_based_cfg') + setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) checkpoint_cfg = dict( type='CheckpointHook', - interval=save_iterval, + interval=2, by_epoch=training_type == 'epoch') - cfg.default_hooks = dict(checkpoint=checkpoint_cfg) + common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg) + + # Test interval in epoch based training + cfg = copy.deepcopy(common_cfg) runner = self.build_runner(cfg) runner.train() for i in range(1, 11): - if i == 0: - self.assertFalse( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) - if i % 2 == 0: - self.assertTrue( - osp.isfile( - osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.assertEqual( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')), + i % 2 == 0) + # save_last=True self.assertTrue( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) + self.clear_work_dir() + # Test save_optimizer=False + cfg = copy.deepcopy(common_cfg) + runner = self.build_runner(cfg) + runner.train() ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) self.assertIn('optimizer', ckpt) + cfg.default_hooks.checkpoint.save_optimizer = False runner = self.build_runner(cfg) runner.train() @@ -464,6 +467,7 @@ class TestCheckpointHook(RunnerTestCase): self.assertNotIn('optimizer', ckpt) # Test save_param_scheduler=False + cfg = copy.deepcopy(common_cfg) cfg.param_scheduler = [ dict( type='LinearLR', @@ -483,7 +487,10 @@ class TestCheckpointHook(RunnerTestCase): ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) self.assertNotIn('param_schedulers', ckpt) + self.clear_work_dir() + # Test out_dir + cfg = copy.deepcopy(common_cfg) out_dir = osp.join(self.temp_dir.name, 'out_dir') cfg.default_hooks.checkpoint.out_dir = out_dir runner = self.build_runner(cfg) @@ -493,37 +500,54 @@ class TestCheckpointHook(RunnerTestCase): osp.join(out_dir, osp.basename(cfg.work_dir), f'{training_type}_11.pth'))) - # Test max_keep_ckpts. - del cfg.default_hooks.checkpoint.out_dir + self.clear_work_dir() + + # Test max_keep_ckpts + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.interval = 1 cfg.default_hooks.checkpoint.max_keep_ckpts = 1 runner = self.build_runner(cfg) runner.train() + print(os.listdir(cfg.work_dir)) self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth'))) + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) - for i in range(10): + for i in range(11): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + self.clear_work_dir() + # Test filename_tmpl + cfg = copy.deepcopy(common_cfg) cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth' runner = self.build_runner(cfg) runner.train() - self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth'))) + self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_11.pth'))) + + self.clear_work_dir() # Test save_best + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.interval = 1 cfg.default_hooks.checkpoint.save_best = 'test/acc' cfg.val_evaluator = dict(type='TriangleMetric', length=11) cfg.train_cfg.val_interval = 1 runner = self.build_runner(cfg) runner.train() - self.assertTrue( - osp.isfile(osp.join(cfg.work_dir, 'best_test_acc_test_5.pth'))) + best_ckpt = osp.join(cfg.work_dir, + f'best_test_acc_{training_type}_5.pth') + self.assertTrue(osp.isfile(best_ckpt)) + + self.clear_work_dir() # test save published keys + cfg = copy.deepcopy(common_cfg) cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict'] runner = self.build_runner(cfg) runner.train() ckpt_files = os.listdir(runner.work_dir) self.assertTrue( any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files)) + + self.clear_work_dir()