mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Let unit tests not affect each other (#1169)
This commit is contained in:
parent
5d4e72144a
commit
193b7fdfcc
@ -85,7 +85,7 @@ class CheckpointHook(Hook):
|
||||
accordingly.
|
||||
backend_args (dict, optional): Arguments to instantiate the
|
||||
prefix of uri corresponding backend. Defaults to None.
|
||||
New in v0.2.0.
|
||||
`New in version 0.2.0.`
|
||||
published_keys (str, List[str], optional): If ``save_last`` is ``True``
|
||||
or ``save_best`` is not ``None``, it will automatically
|
||||
publish model with keys in the list after training.
|
||||
|
@ -429,34 +429,37 @@ class TestCheckpointHook(RunnerTestCase):
|
||||
|
||||
@parameterized.expand([['iter'], ['epoch']])
|
||||
def test_with_runner(self, training_type):
|
||||
# Test interval in epoch based training
|
||||
save_iterval = 2
|
||||
cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg'))
|
||||
setattr(cfg.train_cfg, f'max_{training_type}s', 11)
|
||||
common_cfg = getattr(self, f'{training_type}_based_cfg')
|
||||
setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
|
||||
checkpoint_cfg = dict(
|
||||
type='CheckpointHook',
|
||||
interval=save_iterval,
|
||||
interval=2,
|
||||
by_epoch=training_type == 'epoch')
|
||||
cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
|
||||
common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
|
||||
|
||||
# Test interval in epoch based training
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
for i in range(1, 11):
|
||||
if i == 0:
|
||||
self.assertFalse(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
if i % 2 == 0:
|
||||
self.assertTrue(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
self.assertEqual(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')),
|
||||
i % 2 == 0)
|
||||
|
||||
# save_last=True
|
||||
self.assertTrue(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
|
||||
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test save_optimizer=False
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
|
||||
self.assertIn('optimizer', ckpt)
|
||||
|
||||
cfg.default_hooks.checkpoint.save_optimizer = False
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
@ -464,6 +467,7 @@ class TestCheckpointHook(RunnerTestCase):
|
||||
self.assertNotIn('optimizer', ckpt)
|
||||
|
||||
# Test save_param_scheduler=False
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
@ -483,7 +487,10 @@ class TestCheckpointHook(RunnerTestCase):
|
||||
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
|
||||
self.assertNotIn('param_schedulers', ckpt)
|
||||
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test out_dir
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
out_dir = osp.join(self.temp_dir.name, 'out_dir')
|
||||
cfg.default_hooks.checkpoint.out_dir = out_dir
|
||||
runner = self.build_runner(cfg)
|
||||
@ -493,37 +500,54 @@ class TestCheckpointHook(RunnerTestCase):
|
||||
osp.join(out_dir, osp.basename(cfg.work_dir),
|
||||
f'{training_type}_11.pth')))
|
||||
|
||||
# Test max_keep_ckpts.
|
||||
del cfg.default_hooks.checkpoint.out_dir
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test max_keep_ckpts
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.interval = 1
|
||||
cfg.default_hooks.checkpoint.max_keep_ckpts = 1
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
print(os.listdir(cfg.work_dir))
|
||||
self.assertTrue(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
|
||||
|
||||
for i in range(10):
|
||||
for i in range(11):
|
||||
self.assertFalse(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test filename_tmpl
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth')))
|
||||
self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_11.pth')))
|
||||
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test save_best
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.interval = 1
|
||||
cfg.default_hooks.checkpoint.save_best = 'test/acc'
|
||||
cfg.val_evaluator = dict(type='TriangleMetric', length=11)
|
||||
cfg.train_cfg.val_interval = 1
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
self.assertTrue(
|
||||
osp.isfile(osp.join(cfg.work_dir, 'best_test_acc_test_5.pth')))
|
||||
best_ckpt = osp.join(cfg.work_dir,
|
||||
f'best_test_acc_{training_type}_5.pth')
|
||||
self.assertTrue(osp.isfile(best_ckpt))
|
||||
|
||||
self.clear_work_dir()
|
||||
|
||||
# test save published keys
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
ckpt_files = os.listdir(runner.work_dir)
|
||||
self.assertTrue(
|
||||
any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))
|
||||
|
||||
self.clear_work_dir()
|
||||
|
Loading…
x
Reference in New Issue
Block a user