[Refactor] Refactor checkpointhook unit tests (#789)

* Enhance config

* add unit test data

* Refactor unit test of checkpointhook

* add comments

* Fix unit test

* remove _get_metric_scope

* tmp save

* Revert "remove _get_metric_scope"

This reverts commit eeb7a8c5ed2766bf773a9ed28f731fddacd10ac1.

* Revert "Revert "remove _get_metric_scope""

This reverts commit 5398255f6fb3dac8341f7d808f0d7d09350fcaae.

* Revert "tmp save"

This reverts commit cdc9919be8e0a78bbf264c060de2a4396c137d5a.

* clean the code

* Fix ut

* minor fix

* use str.replace
This commit is contained in:
Mashiro 2023-04-06 10:55:16 +08:00 committed by GitHub
parent dc931fd2c0
commit 2dbc8ed253
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 272 additions and 288 deletions

View File

@ -486,6 +486,8 @@ class CheckpointHook(Hook):
'is removed') 'is removed')
best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}' best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
# Replace illegal characters for filename with `_`
best_ckpt_name = best_ckpt_name.replace('/', '_')
if len(self.key_indicators) == 1: if len(self.key_indicators) == 1:
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501 self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name) self.out_dir, best_ckpt_name)

View File

@ -1,56 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy
import os import os
import os.path as osp import os.path as osp
from unittest.mock import Mock import re
import pytest
import torch import torch
import torch.nn as nn from parameterized import parameterized
from torch.utils.data import Dataset
from mmengine.evaluator import BaseMetric from mmengine.evaluator import BaseMetric
from mmengine.fileio import FileClient, LocalBackend from mmengine.fileio import FileClient, LocalBackend
from mmengine.hooks import CheckpointHook from mmengine.hooks import CheckpointHook
from mmengine.logging import MessageHub from mmengine.logging import MessageHub
from mmengine.model import BaseModel from mmengine.registry import METRICS
from mmengine.optim import OptimWrapper from mmengine.testing import RunnerTestCase
from mmengine.runner import Runner
class ToyModel(BaseModel):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
def forward(self, inputs, data_sample, mode='tensor'):
labels = torch.stack(data_sample)
inputs = torch.stack(inputs)
outputs = self.linear(inputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
else:
return outputs
class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])
class TriangleMetric(BaseMetric): class TriangleMetric(BaseMetric):
@ -72,119 +34,143 @@ class TriangleMetric(BaseMetric):
return dict(acc=acc) return dict(acc=acc)
class TestCheckpointHook: class TestCheckpointHook(RunnerTestCase):
def test_init(self, tmp_path): def setUp(self):
super().setUp()
METRICS.register_module(module=TriangleMetric, force=True)
def tearDown(self):
return METRICS.module_dict.clear()
def test_init(self):
# Test file_client_args and backend_args # Test file_client_args and backend_args
# TODO: Refactor this test case # TODO: Refactor this test case
# with pytest.warns( # with self.assertWarnsRegex(
# DeprecationWarning, # DeprecationWarning,
# match='"file_client_args" will be deprecated in future'): # '"file_client_args" will be deprecated in future'):
# CheckpointHook(file_client_args={'backend': 'disk'}) # CheckpointHook(file_client_args={'backend': 'disk'})
with pytest.raises( with self.assertRaisesRegex(
ValueError, ValueError,
match='"file_client_args" and "backend_args" cannot be set ' '"file_client_args" and "backend_args" cannot be set '
'at the same time'): 'at the same time'):
CheckpointHook( CheckpointHook(
file_client_args={'backend': 'disk'}, file_client_args={'backend': 'disk'},
backend_args={'backend': 'local'}) backend_args={'backend': 'local'})
def test_before_train(self, tmp_path): # Test save best
runner = Mock() CheckpointHook(save_best='acc')
work_dir = str(tmp_path) CheckpointHook(save_best=['acc'])
runner.work_dir = work_dir
with self.assertRaisesRegex(AssertionError, '"save_best" should be'):
CheckpointHook(save_best=dict(acc='acc'))
# error when 'auto' in `save_best` list
with self.assertRaisesRegex(AssertionError, 'Only support one'):
CheckpointHook(interval=2, save_best=['auto', 'acc'])
# Test rules
CheckpointHook(save_best=['acc', 'mAcc'], rule='greater')
with self.assertRaisesRegex(AssertionError, '"rule" should be a str'):
CheckpointHook(save_best=['acc'], rule=1)
with self.assertRaisesRegex(AssertionError,
'Number of "rule" must be'):
CheckpointHook(save_best=['acc'], rule=['greater', 'loss'])
# Test greater_keys
hook = CheckpointHook(greater_keys='acc')
self.assertEqual(hook.greater_keys, ('acc', ))
hook = CheckpointHook(greater_keys=['acc'])
self.assertEqual(hook.greater_keys, ['acc'])
hook = CheckpointHook(
interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
self.assertEqual(hook.key_indicators, ['acc', 'mIoU'])
self.assertEqual(hook.rules, ['greater', 'greater'])
# Test less keys
hook = CheckpointHook(less_keys='loss_cls')
self.assertEqual(hook.less_keys, ('loss_cls', ))
hook = CheckpointHook(less_keys=['loss_cls'])
self.assertEqual(hook.less_keys, ['loss_cls'])
def test_before_train(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
# file_client_args is None # file_client_args is None
checkpoint_hook = CheckpointHook() checkpoint_hook = CheckpointHook()
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert isinstance(checkpoint_hook.file_client, FileClient) self.assertIsInstance(checkpoint_hook.file_client, FileClient)
assert isinstance(checkpoint_hook.file_backend, LocalBackend) self.assertIsInstance(checkpoint_hook.file_backend, LocalBackend)
# file_client_args is not None # file_client_args is not None
checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'}) checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'})
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert isinstance(checkpoint_hook.file_client, FileClient) self.assertIsInstance(checkpoint_hook.file_client, FileClient)
# file_backend is the alias of file_client # file_backend is the alias of file_client
assert checkpoint_hook.file_backend is checkpoint_hook.file_client self.assertIs(checkpoint_hook.file_backend,
checkpoint_hook.file_client)
# the out_dir of the checkpoint hook is None # the out_dir of the checkpoint hook is None
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True) checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.out_dir == runner.work_dir self.assertEqual(checkpoint_hook.out_dir, runner.work_dir)
# the out_dir of the checkpoint hook is not None # the out_dir of the checkpoint hook is not None
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir') interval=1, by_epoch=True, out_dir='test_dir')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.out_dir == osp.join( self.assertEqual(checkpoint_hook.out_dir,
'test_dir', osp.join(osp.basename(work_dir))) osp.join('test_dir', osp.basename(cfg.work_dir)))
runner.message_hub = MessageHub.get_instance('test_before_train') # If `save_best` is a list of string, the path to save the best
# no 'best_ckpt_path' in runtime_info # checkpoint will be defined in attribute `best_ckpt_path_dict`.
checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU']) checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None) self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
assert not hasattr(checkpoint_hook, 'best_ckpt_path') dict(acc=None, mIoU=None))
self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path'))
# only one 'best_ckpt_path' in runtime_info # Resume 'best_ckpt_path' from message_hub
runner.message_hub.update_info('best_ckpt_acc', 'best_acc') runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.best_ckpt_path_dict == dict( self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
acc='best_acc', mIoU=None) dict(acc='best_acc', mIoU=None))
# no 'best_ckpt_path' in runtime_info # If `save_best` is a string, the path to save best ckpt will be
# defined in attribute `best_ckpt_path`
checkpoint_hook = CheckpointHook(interval=1, save_best='acc') checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.best_ckpt_path is None self.assertIsNone(checkpoint_hook.best_ckpt_path)
assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict') self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path_dict'))
# 'best_ckpt_path' in runtime_info # Resume `best_ckpt` path from message_hub
runner.message_hub.update_info('best_ckpt', 'best_ckpt') runner.message_hub.update_info('best_ckpt', 'best_ckpt')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.best_ckpt_path == 'best_ckpt' self.assertEqual(checkpoint_hook.best_ckpt_path, 'best_ckpt')
def test_after_val_epoch(self, tmp_path): def test_after_val_epoch(self):
runner = Mock() cfg = copy.deepcopy(self.epoch_based_cfg)
runner.work_dir = tmp_path runner = self.build_runner(cfg)
runner.epoch = 9 runner.train_loop._epoch = 9
runner.model = Mock()
runner.logger.warning = Mock()
runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
with pytest.raises(ValueError):
# key_indicator must be valid when rule_map is None
CheckpointHook(interval=2, by_epoch=True, save_best='unsupport')
with pytest.raises(KeyError):
# rule must be in keys of rule_map
CheckpointHook(
interval=2, by_epoch=True, save_best='auto', rule='unsupport')
# if metrics is an empty dict, print a warning information # if metrics is an empty dict, print a warning information
checkpoint_hook = CheckpointHook( with self.assertLogs(runner.logger, level='WARNING'):
interval=2, by_epoch=True, save_best='auto') checkpoint_hook = CheckpointHook(
checkpoint_hook.after_val_epoch(runner, {}) interval=2, by_epoch=True, save_best='auto')
runner.logger.warning.assert_called_once() checkpoint_hook.after_val_epoch(runner, {})
# test error when number of rules and metrics are not same # if save_best is None,no best_ckpt meta should be stored
with pytest.raises(AssertionError) as assert_error:
CheckpointHook(
interval=1,
save_best=['mIoU', 'acc'],
rule=['greater', 'greater', 'less'],
by_epoch=True)
error_message = ('Number of "rule" must be 1 or the same as number of '
'"save_best", but got 3.')
assert error_message in str(assert_error.value)
# if save_best is None, no best_ckpt meta should be stored
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, save_best=None) interval=2, by_epoch=True, save_best=None)
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, {}) checkpoint_hook.after_val_epoch(runner, {})
assert 'best_score' not in runner.message_hub.runtime_info self.assertNotIn('best_score', runner.message_hub.runtime_info)
assert 'best_ckpt' not in runner.message_hub.runtime_info self.assertNotIn('best_ckpt', runner.message_hub.runtime_info)
# when `save_best` is set to `auto`, first metric will be used. # when `save_best` is set to `auto`, first metric will be used.
metrics = {'acc': 0.5, 'map': 0.3} metrics = {'acc': 0.5, 'map': 0.3}
@ -195,12 +181,11 @@ class TestCheckpointHook:
best_ckpt_name = 'best_acc_epoch_9.pth' best_ckpt_name = 'best_acc_epoch_9.pth'
best_ckpt_path = checkpoint_hook.file_client.join_path( best_ckpt_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_ckpt_name) checkpoint_hook.out_dir, best_ckpt_name)
assert checkpoint_hook.key_indicators == ['acc'] self.assertEqual(checkpoint_hook.key_indicators, ['acc'])
assert checkpoint_hook.rules == ['greater'] self.assertEqual(checkpoint_hook.rules, ['greater'])
assert 'best_score' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
runner.message_hub.get_info('best_score') == 0.5 self.assertEqual(
assert 'best_ckpt' in runner.message_hub.runtime_info and \ runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
# # when `save_best` is set to `acc`, it should update greater value # # when `save_best` is set to `acc`, it should update greater value
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
@ -208,8 +193,7 @@ class TestCheckpointHook:
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
metrics['acc'] = 0.8 metrics['acc'] = 0.8
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score'), 0.8)
runner.message_hub.get_info('best_score') == 0.8
# # when `save_best` is set to `loss`, it should update less value # # when `save_best` is set to `loss`, it should update less value
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
@ -219,8 +203,7 @@ class TestCheckpointHook:
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
metrics['loss'] = 0.5 metrics['loss'] = 0.5
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
runner.message_hub.get_info('best_score') == 0.5
# when `rule` is set to `less`,then it should update less value # when `rule` is set to `less`,then it should update less value
# no matter what `save_best` is # no matter what `save_best` is
@ -229,8 +212,7 @@ class TestCheckpointHook:
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
metrics['acc'] = 0.3 metrics['acc'] = 0.3
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score'), 0.3)
runner.message_hub.get_info('best_score') == 0.3
# # when `rule` is set to `greater`,then it should update greater value # # when `rule` is set to `greater`,then it should update greater value
# # no matter what `save_best` is # # no matter what `save_best` is
@ -239,25 +221,24 @@ class TestCheckpointHook:
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
metrics['loss'] = 1.0 metrics['loss'] = 1.0
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
assert 'best_score' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score'), 1.0)
runner.message_hub.get_info('best_score') == 1.0
# test multi `save_best` with one rule # test multi `save_best` with one rule
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=2, save_best=['acc', 'mIoU'], rule='greater') interval=2, save_best=['acc', 'mIoU'], rule='greater')
assert checkpoint_hook.key_indicators == ['acc', 'mIoU'] self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU'])
assert checkpoint_hook.rules == ['greater', 'greater'] self.assertEqual(checkpoint_hook.rules, ['greater', 'greater'])
# test multi `save_best` with multi rules # test multi `save_best` with multi rules
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=2, save_best=['FID', 'IS'], rule=['less', 'greater']) interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
assert checkpoint_hook.key_indicators == ['FID', 'IS'] self.assertEqual(checkpoint_hook.key_indicators, ['FID', 'IS'])
assert checkpoint_hook.rules == ['less', 'greater'] self.assertEqual(checkpoint_hook.rules, ['less', 'greater'])
# test multi `save_best` with default rule # test multi `save_best` with default rule
checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU']) checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
assert checkpoint_hook.key_indicators == ['acc', 'mIoU'] self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU'])
assert checkpoint_hook.rules == ['greater', 'greater'] self.assertEqual(checkpoint_hook.rules, ['greater', 'greater'])
runner.message_hub = MessageHub.get_instance( runner.message_hub = MessageHub.get_instance(
'test_after_val_epoch_save_multi_best') 'test_after_val_epoch_save_multi_best')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
@ -269,22 +250,17 @@ class TestCheckpointHook:
best_mIoU_name = 'best_mIoU_epoch_9.pth' best_mIoU_name = 'best_mIoU_epoch_9.pth'
best_mIoU_path = checkpoint_hook.file_client.join_path( best_mIoU_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_mIoU_name) checkpoint_hook.out_dir, best_mIoU_name)
assert 'best_score_acc' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
runner.message_hub.get_info('best_score_acc') == 0.5 self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('best_score_mIoU') == 0.6 runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('best_ckpt_acc') == best_acc_path runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
# test behavior when by_epoch is False # test behavior when by_epoch is False
runner = Mock() cfg = copy.deepcopy(self.iter_based_cfg)
runner.work_dir = tmp_path runner = self.build_runner(cfg)
runner.iter = 9 runner.train_loop._iter = 9
runner.model = Mock()
runner.message_hub = MessageHub.get_instance(
'test_after_val_epoch_by_epoch_is_false')
# check best ckpt name and best score # check best ckpt name and best score
metrics = {'acc': 0.5, 'map': 0.3} metrics = {'acc': 0.5, 'map': 0.3}
@ -292,15 +268,15 @@ class TestCheckpointHook:
interval=2, by_epoch=False, save_best='acc', rule='greater') interval=2, by_epoch=False, save_best='acc', rule='greater')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
assert checkpoint_hook.key_indicators == ['acc'] self.assertEqual(checkpoint_hook.key_indicators, ['acc'])
assert checkpoint_hook.rules == ['greater'] self.assertEqual(checkpoint_hook.rules, ['greater'])
best_ckpt_name = 'best_acc_iter_9.pth' best_ckpt_name = 'best_acc_iter_9.pth'
best_ckpt_path = checkpoint_hook.file_client.join_path( best_ckpt_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_ckpt_name) checkpoint_hook.out_dir, best_ckpt_name)
assert 'best_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt') == best_ckpt_path self.assertEqual(
assert 'best_score' in runner.message_hub.runtime_info and \ runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
runner.message_hub.get_info('best_score') == 0.5 self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
# check best score updating # check best score updating
metrics['acc'] = 0.666 metrics['acc'] = 0.666
@ -308,25 +284,13 @@ class TestCheckpointHook:
best_ckpt_name = 'best_acc_iter_9.pth' best_ckpt_name = 'best_acc_iter_9.pth'
best_ckpt_path = checkpoint_hook.file_client.join_path( best_ckpt_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_ckpt_name) checkpoint_hook.out_dir, best_ckpt_name)
assert 'best_ckpt' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('best_ckpt') == best_ckpt_path runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
assert 'best_score' in runner.message_hub.runtime_info and \ self.assertEqual(runner.message_hub.get_info('best_score'), 0.666)
runner.message_hub.get_info('best_score') == 0.666
# error when 'auto' in `save_best` list
with pytest.raises(AssertionError):
CheckpointHook(interval=2, save_best=['auto', 'acc'])
# error when one `save_best` with multi `rule`
with pytest.raises(AssertionError):
CheckpointHook(
interval=2, save_best='acc', rule=['greater', 'less'])
# check best checkpoint name with `by_epoch` is False # check best checkpoint name with `by_epoch` is False
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=2, by_epoch=False, save_best=['acc', 'mIoU']) interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
assert checkpoint_hook.rules == ['greater', 'greater']
runner.message_hub = MessageHub.get_instance(
'test_after_val_epoch_save_multi_best_by_epoch_is_false')
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
metrics = dict(acc=0.5, mIoU=0.6) metrics = dict(acc=0.5, mIoU=0.6)
checkpoint_hook.after_val_epoch(runner, metrics) checkpoint_hook.after_val_epoch(runner, metrics)
@ -336,162 +300,180 @@ class TestCheckpointHook:
best_mIoU_name = 'best_mIoU_iter_9.pth' best_mIoU_name = 'best_mIoU_iter_9.pth'
best_mIoU_path = checkpoint_hook.file_client.join_path( best_mIoU_path = checkpoint_hook.file_client.join_path(
checkpoint_hook.out_dir, best_mIoU_name) checkpoint_hook.out_dir, best_mIoU_name)
assert 'best_score_acc' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score_acc') == 0.5
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score_mIoU') == 0.6
assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
# after_val_epoch should not save last_checkpoint. self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
assert not osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')) self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
self.assertEqual(
runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
self.assertEqual(
runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
def test_after_train_epoch(self, tmp_path): # after_val_epoch should not save last_checkpoint
runner = Mock() self.assertFalse(
work_dir = str(tmp_path) osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))
runner.work_dir = tmp_path
runner.epoch = 9 def test_after_train_epoch(self):
runner.model = Mock() cfg = copy.deepcopy(self.epoch_based_cfg)
runner.message_hub = MessageHub.get_instance('test_after_train_epoch') runner = self.build_runner(cfg)
runner.train_loop._epoch = 9
runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
# by epoch is True # by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert (runner.epoch + 1) % 2 == 0 self.assertEqual((runner.epoch + 1) % 2, 0)
assert 'last_ckpt' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('last_ckpt') == \ runner.message_hub.get_info('last_ckpt'),
osp.join(work_dir, 'epoch_10.pth') osp.join(cfg.work_dir, 'epoch_10.pth'))
last_ckpt_path = osp.join(work_dir, 'last_checkpoint')
assert osp.isfile(last_ckpt_path) last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint')
self.assertTrue(osp.isfile(last_ckpt_path))
with open(last_ckpt_path) as f: with open(last_ckpt_path) as f:
filepath = f.read() filepath = f.read()
assert filepath == osp.join(work_dir, 'epoch_10.pth') self.assertEqual(filepath, osp.join(cfg.work_dir, 'epoch_10.pth'))
# epoch can not be evenly divided by 2 # epoch can not be evenly divided by 2
runner.epoch = 10 runner.train_loop._epoch = 10
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert 'last_ckpt' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('last_ckpt') == \ runner.message_hub.get_info('last_ckpt'),
osp.join(work_dir, 'epoch_10.pth') osp.join(cfg.work_dir, 'epoch_10.pth'))
runner.message_hub.runtime_info.clear()
# by epoch is False # by epoch is False
runner.epoch = 9 runner.train_loop._epoch = 9
runner.message_hub = MessageHub.get_instance('test_after_train_epoch1')
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert 'last_ckpt' not in runner.message_hub.runtime_info self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
runner.message_hub.runtime_info.clear()
# # max_keep_ckpts > 0
runner.work_dir = work_dir
os.system(f'touch {work_dir}/epoch_8.pth')
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, max_keep_ckpts=1)
checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_epoch(runner)
assert (runner.epoch + 1) % 2 == 0
assert not os.path.exists(f'{work_dir}/epoch_8.pth')
# save_checkpoint of runner should be called with expected arguments
runner = Mock()
work_dir = str(tmp_path)
runner.work_dir = tmp_path
runner.epoch = 1
runner.message_hub = MessageHub.get_instance('test_after_train_epoch2')
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_epoch(runner)
runner.save_checkpoint.assert_called_once_with(
runner.work_dir,
'epoch_2.pth',
None,
backend_args=None,
by_epoch=True,
save_optimizer=True,
save_param_scheduler=True)
def test_after_train_iter(self, tmp_path):
work_dir = str(tmp_path)
runner = Mock()
runner.work_dir = str(work_dir)
runner.iter = 9
batch_idx = 9
runner.model = Mock()
runner.message_hub = MessageHub.get_instance('test_after_train_iter')
def test_after_train_iter(self):
# by epoch is True # by epoch is True
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
runner.train_loop._iter = 9
runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) checkpoint_hook.after_train_iter(runner, batch_idx=9)
assert 'last_ckpt' not in runner.message_hub.runtime_info self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
# by epoch is False # by epoch is False
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) checkpoint_hook.after_train_iter(runner, batch_idx=9)
assert (runner.iter + 1) % 2 == 0 self.assertIn('last_ckpt', runner.message_hub.runtime_info)
assert 'last_ckpt' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('last_ckpt') == \ runner.message_hub.get_info('last_ckpt'),
osp.join(work_dir, 'iter_10.pth') osp.join(cfg.work_dir, 'iter_10.pth'))
# epoch can not be evenly divided by 2 # epoch can not be evenly divided by 2
runner.iter = 10 runner.train_loop._iter = 10
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert 'last_ckpt' in runner.message_hub.runtime_info and \ self.assertEqual(
runner.message_hub.get_info('last_ckpt') == \ runner.message_hub.get_info('last_ckpt'),
osp.join(work_dir, 'iter_10.pth') osp.join(cfg.work_dir, 'iter_10.pth'))
# max_keep_ckpts > 0 @parameterized.expand([['iter'], ['epoch']])
runner.iter = 9 def test_with_runner(self, training_type):
runner.work_dir = work_dir # Test interval in epoch based training
os.system(f'touch {osp.join(work_dir, "iter_8.pth")}') save_iterval = 2
checkpoint_hook = CheckpointHook( cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg'))
interval=2, by_epoch=False, max_keep_ckpts=1) setattr(cfg.train_cfg, f'max_{training_type}s', 11)
checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert not os.path.exists(f'{work_dir}/iter_8.pth')
def test_with_runner(self, tmp_path):
max_epoch = 10
work_dir = osp.join(str(tmp_path), 'runner_test')
tmpl = '{}.pth'
save_interval = 2
checkpoint_cfg = dict( checkpoint_cfg = dict(
type='CheckpointHook', type='CheckpointHook',
interval=save_interval, interval=save_iterval,
filename_tmpl=tmpl, by_epoch=training_type == 'epoch')
by_epoch=True, cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
save_best='test/acc', runner = self.build_runner(cfg)
rule='less',
published_keys=['meta', 'state_dict'])
runner = Runner(
model=ToyModel(),
work_dir=work_dir,
train_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=dict(type=TriangleMetric, length=max_epoch),
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(
by_epoch=True, max_epochs=max_epoch, val_interval=1),
val_cfg=dict(),
default_hooks=dict(checkpoint=checkpoint_cfg))
runner.train() runner.train()
for epoch in range(max_epoch):
if epoch % save_interval != 0 or epoch == 0: for i in range(1, 11):
continue if i == 0:
path = osp.join(work_dir, tmpl.format(epoch)) self.assertFalse(
assert osp.isfile(path=path) osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
if i % 2 == 0:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
# Test save_optimizer=False
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertIn('optimizer', ckpt)
cfg.default_hooks.checkpoint.save_optimizer = False
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertNotIn('optimizer', ckpt)
# Test save_param_scheduler=False
cfg.param_scheduler = [
dict(
type='LinearLR',
start_factor=0.1,
begin=0,
end=500,
by_epoch=training_type == 'epoch')
]
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertIn('param_schedulers', ckpt)
cfg.default_hooks.checkpoint.save_param_scheduler = False
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertNotIn('param_schedulers', ckpt)
# Test out_dir
out_dir = osp.join(self.temp_dir.name, 'out_dir')
cfg.default_hooks.checkpoint.out_dir = out_dir
runner = self.build_runner(cfg)
runner.train()
self.assertTrue(
osp.isfile(
osp.join(out_dir, osp.basename(cfg.work_dir),
f'{training_type}_11.pth')))
# Test max_keep_ckpts.
del cfg.default_hooks.checkpoint.out_dir
cfg.default_hooks.checkpoint.max_keep_ckpts = 1
runner = self.build_runner(cfg)
runner.train()
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
for i in range(10):
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
# Test filename_tmpl
cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
runner = self.build_runner(cfg)
runner.train()
self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth')))
# Test save_best
cfg.default_hooks.checkpoint.save_best = 'test/acc'
cfg.val_evaluator = dict(type='TriangleMetric', length=11)
cfg.train_cfg.val_interval = 1
runner = self.build_runner(cfg)
runner.train()
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, 'best_test_acc_test_5.pth')))
# test save published keys
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
runner = self.build_runner(cfg)
runner.train()
ckpt_files = os.listdir(runner.work_dir)
self.assertTrue(
any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))