Mirror of https://github.com/open-mmlab/mmengine.git
[Refactor] Let unit tests not affect each other (#1169)
commit 193b7fdfcc
parent 5d4e72144a
@@ -85,7 +85,7 @@ class CheckpointHook(Hook):
             accordingly.
         backend_args (dict, optional): Arguments to instantiate the
             prefix of uri corresponding backend. Defaults to None.
-            New in v0.2.0.
+            `New in version 0.2.0.`
         published_keys (str, List[str], optional): If ``save_last`` is ``True``
             or ``save_best`` is not ``None``, it will automatically
             publish model with keys in the list after training.
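For context, a minimal sketch of how the options documented above might appear in a runner config; the values below are illustrative assumptions, not taken from this commit:

# Illustrative CheckpointHook configuration (values are assumptions).
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,      # save every epoch/iteration
        save_last=True,  # always keep the final checkpoint
        # With save_last=True (or save_best set), publish a stripped
        # checkpoint containing only these keys after training.
        published_keys=['meta', 'state_dict']))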
@@ -429,34 +429,37 @@ class TestCheckpointHook(RunnerTestCase):
 
     @parameterized.expand([['iter'], ['epoch']])
     def test_with_runner(self, training_type):
-        # Test interval in epoch based training
-        save_iterval = 2
-        cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg'))
-        setattr(cfg.train_cfg, f'max_{training_type}s', 11)
+        common_cfg = getattr(self, f'{training_type}_based_cfg')
+        setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
         checkpoint_cfg = dict(
             type='CheckpointHook',
-            interval=save_iterval,
+            interval=2,
             by_epoch=training_type == 'epoch')
-        cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
+        common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
+
+        # Test interval in epoch based training
+        cfg = copy.deepcopy(common_cfg)
         runner = self.build_runner(cfg)
         runner.train()
 
         for i in range(1, 11):
-            if i == 0:
-                self.assertFalse(
-                    osp.isfile(
-                        osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
-            if i % 2 == 0:
-                self.assertTrue(
-                    osp.isfile(
-                        osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
+            self.assertEqual(
+                osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')),
+                i % 2 == 0)
+
+        # save_last=True
         self.assertTrue(
             osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
 
+        self.clear_work_dir()
+
         # Test save_optimizer=False
+        cfg = copy.deepcopy(common_cfg)
+        runner = self.build_runner(cfg)
+        runner.train()
         ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
         self.assertIn('optimizer', ckpt)
 
         cfg.default_hooks.checkpoint.save_optimizer = False
         runner = self.build_runner(cfg)
         runner.train()
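The hunk above is the heart of the refactor: build one shared base config, deep-copy it for each sub-test, and clear artifacts between runs, so earlier sub-tests cannot affect later ones. A condensed sketch of the pattern, with build_runner/clear_work_dir standing in for the RunnerTestCase helpers (names reused from the test, logic simplified):

import copy

def run_isolated_case(common_cfg, build_runner, clear_work_dir):
    # Deep-copy so in-place mutations (e.g. save_optimizer=False)
    # cannot leak into the shared config used by later sub-tests.
    cfg = copy.deepcopy(common_cfg)
    runner = build_runner(cfg)
    runner.train()
    # ... assertions against cfg.work_dir go here ...
    clear_work_dir()  # remove checkpoints so later cases see a clean dir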
@@ -464,6 +467,7 @@ class TestCheckpointHook(RunnerTestCase):
         self.assertNotIn('optimizer', ckpt)
 
         # Test save_param_scheduler=False
+        cfg = copy.deepcopy(common_cfg)
         cfg.param_scheduler = [
             dict(
                 type='LinearLR',
@@ -483,7 +487,10 @@ class TestCheckpointHook(RunnerTestCase):
         ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
         self.assertNotIn('param_schedulers', ckpt)
 
+        self.clear_work_dir()
+
         # Test out_dir
+        cfg = copy.deepcopy(common_cfg)
         out_dir = osp.join(self.temp_dir.name, 'out_dir')
         cfg.default_hooks.checkpoint.out_dir = out_dir
         runner = self.build_runner(cfg)
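The save_optimizer and save_param_scheduler assertions in these sections boil down to inspecting the keys of the saved checkpoint dict. A standalone sketch (the path is illustrative, not from the test):

import torch

ckpt = torch.load('work_dir/epoch_11.pth', map_location='cpu')
assert 'state_dict' in ckpt            # model weights are always saved
assert 'optimizer' not in ckpt         # dropped when save_optimizer=False
assert 'param_schedulers' not in ckpt  # dropped when save_param_scheduler=False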
@@ -493,37 +500,54 @@ class TestCheckpointHook(RunnerTestCase):
             osp.join(out_dir, osp.basename(cfg.work_dir),
                      f'{training_type}_11.pth')))
 
-        # Test max_keep_ckpts.
-        del cfg.default_hooks.checkpoint.out_dir
+        self.clear_work_dir()
+
+        # Test max_keep_ckpts
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.max_keep_ckpts = 1
         runner = self.build_runner(cfg)
         runner.train()
+        print(os.listdir(cfg.work_dir))
         self.assertTrue(
-            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
 
-        for i in range(10):
+        for i in range(11):
             self.assertFalse(
                 osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
 
+        self.clear_work_dir()
+
         # Test filename_tmpl
+        cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
         runner = self.build_runner(cfg)
         runner.train()
-        self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth')))
+        self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_11.pth')))
 
+        self.clear_work_dir()
+
         # Test save_best
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.save_best = 'test/acc'
         cfg.val_evaluator = dict(type='TriangleMetric', length=11)
         cfg.train_cfg.val_interval = 1
         runner = self.build_runner(cfg)
         runner.train()
-        self.assertTrue(
-            osp.isfile(osp.join(cfg.work_dir, 'best_test_acc_test_5.pth')))
+        best_ckpt = osp.join(cfg.work_dir,
+                             f'best_test_acc_{training_type}_5.pth')
+        self.assertTrue(osp.isfile(best_ckpt))
 
+        self.clear_work_dir()
+
         # test save published keys
+        cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
         runner = self.build_runner(cfg)
         runner.train()
         ckpt_files = os.listdir(runner.work_dir)
         self.assertTrue(
             any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))
+
+        self.clear_work_dir()
+
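The final assertion scans for the 8-character hash suffix that publishing appends to the checkpoint filename (e.g. epoch_11-abcd1234.pth; the name and directory below are illustrative). The check in isolation:

import os
import re

ckpt_files = os.listdir('work_dir')  # illustrative directory
published = [f for f in ckpt_files if re.search(r'-[\d\w]{8}\.pth$', f)]
assert published, 'expected at least one published checkpoint'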
|