Refacotr unitest of checkpointhook
parent
4f9995efa0
commit
8fd7791cd5
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from math import inf
|
||||
|
@ -299,8 +300,8 @@ class CheckpointHook(Hook):
|
|||
def _get_metric_score(self, metrics, key_indicator):
|
||||
eval_res = OrderedDict()
|
||||
if metrics is not None:
|
||||
eval_res.update(metrics)
|
||||
|
||||
for key, value in metrics.items():
|
||||
eval_res[key.partition('/')[-1]] = value
|
||||
if len(eval_res) == 0:
|
||||
warnings.warn(
|
||||
'Since `eval_res` is an empty dict, the behavior to save '
|
||||
|
@ -417,6 +418,7 @@ class CheckpointHook(Hook):
|
|||
'is removed')
|
||||
|
||||
best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
|
||||
best_ckpt_name = re.sub(r'(\W)', '_', best_ckpt_name)
|
||||
if len(self.key_indicators) == 1:
|
||||
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
|
||||
self.out_dir, best_ckpt_name)
|
||||
|
|
|
@ -1,56 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import copy
|
||||
import os.path as osp
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
from parameterized import parameterized
|
||||
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.fileio import FileClient, LocalBackend
|
||||
from mmengine.hooks import CheckpointHook
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
labels = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
outputs = self.linear(inputs)
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (labels - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
@property
|
||||
def metainfo(self):
|
||||
return self.METAINFO
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
from mmengine.registry import METRICS
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TriangleMetric(BaseMetric):
|
||||
|
@ -72,428 +32,435 @@ class TriangleMetric(BaseMetric):
|
|||
return dict(acc=acc)
|
||||
|
||||
|
||||
class TestCheckpointHook:
|
||||
class TestCheckpointHook(RunnerTestCase):
|
||||
|
||||
def test_init(self, tmp_path):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
METRICS.register_module(module=TriangleMetric, force=True)
|
||||
|
||||
def tearDown(self):
|
||||
return METRICS.module_dict.clear()
|
||||
|
||||
def test_init(self):
|
||||
# Test file_client_args and backend_args
|
||||
with pytest.warns(
|
||||
with self.assertWarnsRegex(
|
||||
DeprecationWarning,
|
||||
match='"file_client_args" will be deprecated in future'):
|
||||
'"file_client_args" will be deprecated in future'):
|
||||
CheckpointHook(file_client_args={'backend': 'disk'})
|
||||
|
||||
with pytest.raises(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
match='"file_client_args" and "backend_args" cannot be set '
|
||||
'"file_client_args" and "backend_args" cannot be set '
|
||||
'at the same time'):
|
||||
CheckpointHook(
|
||||
file_client_args={'backend': 'disk'},
|
||||
backend_args={'backend': 'local'})
|
||||
|
||||
def test_before_train(self, tmp_path):
|
||||
runner = Mock()
|
||||
work_dir = str(tmp_path)
|
||||
runner.work_dir = work_dir
|
||||
# Test save best
|
||||
CheckpointHook(save_best='acc')
|
||||
|
||||
CheckpointHook(save_best=['acc'])
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, '"save_best" should be'):
|
||||
CheckpointHook(save_best=dict(acc='acc'))
|
||||
|
||||
# error when 'auto' in `save_best` list
|
||||
with self.assertRaisesRegex(AssertionError, 'Only support one'):
|
||||
CheckpointHook(interval=2, save_best=['auto', 'acc'])
|
||||
|
||||
# Test rules
|
||||
|
||||
CheckpointHook(save_best=['acc', 'mAcc'], rule='greater')
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, '"rule" should be a str'):
|
||||
CheckpointHook(save_best=['acc'], rule=1)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'Number of "rule" must be'):
|
||||
CheckpointHook(save_best=['acc'], rule=['greater', 'loss'])
|
||||
|
||||
# Test greater_keys
|
||||
hook = CheckpointHook(greater_keys='acc')
|
||||
self.assertEqual(hook.greater_keys, ('acc', ))
|
||||
|
||||
hook = CheckpointHook(greater_keys=['acc'])
|
||||
self.assertEqual(hook.greater_keys, ['acc'])
|
||||
|
||||
hook = CheckpointHook(
|
||||
interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
|
||||
self.assertEqual(hook.greater_keys, ('acc', 'mIoU'))
|
||||
self.assertEqual(hook.rule, ('greater', 'greater'))
|
||||
|
||||
# Test less keys
|
||||
hook = CheckpointHook(less_keys='loss_cls')
|
||||
self.assertEqual(hook.less_keys, ('loss_cls', ))
|
||||
|
||||
hook = CheckpointHook(less_keys=['loss_cls'])
|
||||
self.assertEqual(hook.less_keys, ['loss_cls'])
|
||||
|
||||
def test_before_train(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
# file_client_args is None
|
||||
checkpoint_hook = CheckpointHook()
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert isinstance(checkpoint_hook.file_client, FileClient)
|
||||
assert isinstance(checkpoint_hook.file_backend, LocalBackend)
|
||||
self.assertIsInstance(checkpoint_hook.file_client, FileClient)
|
||||
self.assertIsInstance(checkpoint_hook.file_backend, LocalBackend)
|
||||
|
||||
# file_client_args is not None
|
||||
checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'})
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert isinstance(checkpoint_hook.file_client, FileClient)
|
||||
self.assertIsInstance(checkpoint_hook.file_client, FileClient)
|
||||
# file_backend is the alias of file_client
|
||||
assert checkpoint_hook.file_backend is checkpoint_hook.file_client
|
||||
self.assertIs(checkpoint_hook.file_backend,
|
||||
checkpoint_hook.file_client)
|
||||
|
||||
# the out_dir of the checkpoint hook is None
|
||||
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert checkpoint_hook.out_dir == runner.work_dir
|
||||
self.assertEqual(checkpoint_hook.out_dir, runner.work_dir)
|
||||
|
||||
# the out_dir of the checkpoint hook is not None
|
||||
checkpoint_hook = CheckpointHook(
|
||||
interval=1, by_epoch=True, out_dir='test_dir')
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert checkpoint_hook.out_dir == osp.join(
|
||||
'test_dir', osp.join(osp.basename(work_dir)))
|
||||
self.assertEqual(
|
||||
checkpoint_hook.out_dir,
|
||||
osp.join('test_dir', osp.join(osp.basename(cfg.work_dir))))
|
||||
|
||||
runner.message_hub = MessageHub.get_instance('test_before_train')
|
||||
# no 'best_ckpt_path' in runtime_info
|
||||
# If `save_best` is a list of string, the path to save the best
|
||||
# checkpoint will be defined in attribute `best_ckpt_path_dict`.
|
||||
checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None)
|
||||
assert not hasattr(checkpoint_hook, 'best_ckpt_path')
|
||||
self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
|
||||
dict(acc=None, mIoU=None))
|
||||
self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path'))
|
||||
|
||||
# only one 'best_ckpt_path' in runtime_info
|
||||
# Resume 'best_ckpt_path' from message_hub
|
||||
runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert checkpoint_hook.best_ckpt_path_dict == dict(
|
||||
acc='best_acc', mIoU=None)
|
||||
self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
|
||||
dict(acc='best_acc', mIoU=None))
|
||||
|
||||
# no 'best_ckpt_path' in runtime_info
|
||||
# If `save_best` is a string, the path to save best ckpt will be
|
||||
# defined in attribute `best_ckpt_path`
|
||||
checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert checkpoint_hook.best_ckpt_path is None
|
||||
assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict')
|
||||
self.assertIsNone(checkpoint_hook.best_ckpt_path)
|
||||
self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path_dict'))
|
||||
|
||||
# 'best_ckpt_path' in runtime_info
|
||||
# Resume `best_ckpt` path from message_hub
|
||||
runner.message_hub.update_info('best_ckpt', 'best_ckpt')
|
||||
checkpoint_hook.before_train(runner)
|
||||
assert checkpoint_hook.best_ckpt_path == 'best_ckpt'
|
||||
self.assertEqual(checkpoint_hook.best_ckpt_path, 'best_ckpt')
|
||||
|
||||
def test_after_val_epoch(self, tmp_path):
|
||||
runner = Mock()
|
||||
runner.work_dir = tmp_path
|
||||
runner.epoch = 9
|
||||
runner.model = Mock()
|
||||
runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# key_indicator must be valid when rule_map is None
|
||||
CheckpointHook(interval=2, by_epoch=True, save_best='unsupport')
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
# rule must be in keys of rule_map
|
||||
CheckpointHook(
|
||||
interval=2, by_epoch=True, save_best='auto', rule='unsupport')
|
||||
|
||||
# if eval_res is an empty dict, print a warning information
|
||||
with pytest.warns(UserWarning) as record_warnings:
|
||||
eval_hook = CheckpointHook(
|
||||
interval=2, by_epoch=True, save_best='auto')
|
||||
eval_hook._get_metric_score(None, None)
|
||||
# Since there will be many warnings thrown, we just need to check
|
||||
# if the expected exceptions are thrown
|
||||
expected_message = (
|
||||
'Since `eval_res` is an empty dict, the behavior to '
|
||||
'save the best checkpoint will be skipped in this '
|
||||
'evaluation.')
|
||||
for warning in record_warnings:
|
||||
if str(warning.message) == expected_message:
|
||||
break
|
||||
else:
|
||||
assert False
|
||||
|
||||
# test error when number of rules and metrics are not same
|
||||
with pytest.raises(AssertionError) as assert_error:
|
||||
CheckpointHook(
|
||||
interval=1,
|
||||
save_best=['mIoU', 'acc'],
|
||||
rule=['greater', 'greater', 'less'],
|
||||
by_epoch=True)
|
||||
error_message = ('Number of "rule" must be 1 or the same as number of '
|
||||
'"save_best", but got 3.')
|
||||
assert error_message in str(assert_error.value)
|
||||
def test_after_val_epoch(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train_loop._epoch = 9
|
||||
|
||||
# if save_best is None,no best_ckpt meta should be stored
|
||||
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
|
||||
eval_hook.before_train(runner)
|
||||
eval_hook.after_val_epoch(runner, None)
|
||||
assert 'best_score' not in runner.message_hub.runtime_info
|
||||
assert 'best_ckpt' not in runner.message_hub.runtime_info
|
||||
ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
|
||||
ckpt_hook.before_train(runner)
|
||||
ckpt_hook.after_val_epoch(runner, None)
|
||||
self.assertNotIn('best_score', runner.message_hub.runtime_info)
|
||||
self.assertNotIn('best_ckpt', runner.message_hub.runtime_info)
|
||||
|
||||
# when `save_best` is set to `auto`, first metric will be used.
|
||||
metrics = {'acc': 0.5, 'map': 0.3}
|
||||
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
|
||||
eval_hook.before_train(runner)
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
|
||||
ckpt_hook.before_train(runner)
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
best_ckpt_name = 'best_acc_epoch_9.pth'
|
||||
best_ckpt_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_ckpt_name)
|
||||
assert eval_hook.key_indicators == ['acc']
|
||||
assert eval_hook.rules == ['greater']
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 0.5
|
||||
assert 'best_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
|
||||
best_ckpt_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_ckpt_name)
|
||||
self.assertEqual(ckpt_hook.key_indicators, ['acc'])
|
||||
self.assertEqual(ckpt_hook.rules, ['greater'])
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
|
||||
|
||||
# # when `save_best` is set to `acc`, it should update greater value
|
||||
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
|
||||
eval_hook.before_train(runner)
|
||||
ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
|
||||
ckpt_hook.before_train(runner)
|
||||
metrics['acc'] = 0.8
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 0.8
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 0.8)
|
||||
|
||||
# # when `save_best` is set to `loss`, it should update less value
|
||||
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
|
||||
eval_hook.before_train(runner)
|
||||
ckpt_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
|
||||
ckpt_hook.before_train(runner)
|
||||
metrics['loss'] = 0.8
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
metrics['loss'] = 0.5
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 0.5
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
|
||||
|
||||
# when `rule` is set to `less`,then it should update less value
|
||||
# no matter what `save_best` is
|
||||
eval_hook = CheckpointHook(
|
||||
ckpt_hook = CheckpointHook(
|
||||
interval=2, by_epoch=True, save_best='acc', rule='less')
|
||||
eval_hook.before_train(runner)
|
||||
ckpt_hook.before_train(runner)
|
||||
metrics['acc'] = 0.3
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 0.3
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 0.3)
|
||||
|
||||
# # when `rule` is set to `greater`,then it should update greater value
|
||||
# # no matter what `save_best` is
|
||||
eval_hook = CheckpointHook(
|
||||
ckpt_hook = CheckpointHook(
|
||||
interval=2, by_epoch=True, save_best='loss', rule='greater')
|
||||
eval_hook.before_train(runner)
|
||||
ckpt_hook.before_train(runner)
|
||||
metrics['loss'] = 1.0
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 1.0
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 1.0)
|
||||
|
||||
# test multi `save_best` with one rule
|
||||
eval_hook = CheckpointHook(
|
||||
ckpt_hook = CheckpointHook(
|
||||
interval=2, save_best=['acc', 'mIoU'], rule='greater')
|
||||
assert eval_hook.key_indicators == ['acc', 'mIoU']
|
||||
assert eval_hook.rules == ['greater', 'greater']
|
||||
self.assertEqual(ckpt_hook.key_indicators, ['acc', 'mIoU'])
|
||||
self.assertEqual(ckpt_hook.rules, ['greater', 'greater'])
|
||||
|
||||
# test multi `save_best` with multi rules
|
||||
eval_hook = CheckpointHook(
|
||||
ckpt_hook = CheckpointHook(
|
||||
interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
|
||||
assert eval_hook.key_indicators == ['FID', 'IS']
|
||||
assert eval_hook.rules == ['less', 'greater']
|
||||
self.assertEqual(ckpt_hook.key_indicators, ['FID', 'IS'])
|
||||
self.assertEqual(ckpt_hook.rules, ['less', 'greater'])
|
||||
|
||||
# test multi `save_best` with default rule
|
||||
eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
|
||||
assert eval_hook.key_indicators == ['acc', 'mIoU']
|
||||
assert eval_hook.rules == ['greater', 'greater']
|
||||
ckpt_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
|
||||
self.assertEqual(ckpt_hook.key_indicators, ['acc', 'mIoU'])
|
||||
self.assertEqual(ckpt_hook.rules, ['greater', 'greater'])
|
||||
runner.message_hub = MessageHub.get_instance(
|
||||
'test_after_val_epoch_save_multi_best')
|
||||
eval_hook.before_train(runner)
|
||||
ckpt_hook.before_train(runner)
|
||||
metrics = dict(acc=0.5, mIoU=0.6)
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
best_acc_name = 'best_acc_epoch_9.pth'
|
||||
best_acc_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_acc_name)
|
||||
best_acc_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_acc_name)
|
||||
best_mIoU_name = 'best_mIoU_epoch_9.pth'
|
||||
best_mIoU_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_mIoU_name)
|
||||
assert 'best_score_acc' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score_acc') == 0.5
|
||||
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score_mIoU') == 0.6
|
||||
assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
|
||||
assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
|
||||
best_mIoU_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_mIoU_name)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
|
||||
|
||||
self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
|
||||
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
|
||||
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
|
||||
|
||||
# test behavior when by_epoch is False
|
||||
runner = Mock()
|
||||
runner.work_dir = tmp_path
|
||||
runner.iter = 9
|
||||
runner.model = Mock()
|
||||
runner.message_hub = MessageHub.get_instance(
|
||||
'test_after_val_epoch_by_epoch_is_false')
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train_loop._iter = 9
|
||||
|
||||
# check best ckpt name and best score
|
||||
metrics = {'acc': 0.5, 'map': 0.3}
|
||||
eval_hook = CheckpointHook(
|
||||
ckpt_hook = CheckpointHook(
|
||||
interval=2, by_epoch=False, save_best='acc', rule='greater')
|
||||
eval_hook.before_train(runner)
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
assert eval_hook.key_indicators == ['acc']
|
||||
assert eval_hook.rules == ['greater']
|
||||
ckpt_hook.before_train(runner)
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
self.assertEqual(ckpt_hook.key_indicators, ['acc'])
|
||||
self.assertEqual(ckpt_hook.rules, ['greater'])
|
||||
best_ckpt_name = 'best_acc_iter_9.pth'
|
||||
best_ckpt_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_ckpt_name)
|
||||
assert 'best_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 0.5
|
||||
best_ckpt_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_ckpt_name)
|
||||
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
|
||||
|
||||
# check best score updating
|
||||
metrics['acc'] = 0.666
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
best_ckpt_name = 'best_acc_iter_9.pth'
|
||||
best_ckpt_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_ckpt_name)
|
||||
assert 'best_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
|
||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score') == 0.666
|
||||
# error when 'auto' in `save_best` list
|
||||
with pytest.raises(AssertionError):
|
||||
CheckpointHook(interval=2, save_best=['auto', 'acc'])
|
||||
# error when one `save_best` with multi `rule`
|
||||
with pytest.raises(AssertionError):
|
||||
CheckpointHook(
|
||||
interval=2, save_best='acc', rule=['greater', 'less'])
|
||||
best_ckpt_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_ckpt_name)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
|
||||
|
||||
self.assertEqual(runner.message_hub.get_info('best_score'), 0.666)
|
||||
|
||||
# check best checkpoint name with `by_epoch` is False
|
||||
eval_hook = CheckpointHook(
|
||||
ckpt_hook = CheckpointHook(
|
||||
interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
|
||||
assert eval_hook.key_indicators == ['acc', 'mIoU']
|
||||
assert eval_hook.rules == ['greater', 'greater']
|
||||
runner.message_hub = MessageHub.get_instance(
|
||||
'test_after_val_epoch_save_multi_best_by_epoch_is_false')
|
||||
eval_hook.before_train(runner)
|
||||
ckpt_hook.before_train(runner)
|
||||
metrics = dict(acc=0.5, mIoU=0.6)
|
||||
eval_hook.after_val_epoch(runner, metrics)
|
||||
ckpt_hook.after_val_epoch(runner, metrics)
|
||||
best_acc_name = 'best_acc_iter_9.pth'
|
||||
best_acc_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_acc_name)
|
||||
best_acc_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_acc_name)
|
||||
best_mIoU_name = 'best_mIoU_iter_9.pth'
|
||||
best_mIoU_path = eval_hook.file_client.join_path(
|
||||
eval_hook.out_dir, best_mIoU_name)
|
||||
assert 'best_score_acc' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score_acc') == 0.5
|
||||
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_score_mIoU') == 0.6
|
||||
assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
|
||||
assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
|
||||
best_mIoU_path = ckpt_hook.file_client.join_path(
|
||||
ckpt_hook.out_dir, best_mIoU_name)
|
||||
|
||||
# after_val_epoch should not save last_checkpoint.
|
||||
assert not osp.isfile(osp.join(runner.work_dir, 'last_checkpoint'))
|
||||
self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
|
||||
self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
|
||||
|
||||
def test_after_train_epoch(self, tmp_path):
|
||||
runner = Mock()
|
||||
work_dir = str(tmp_path)
|
||||
runner.work_dir = tmp_path
|
||||
runner.epoch = 9
|
||||
runner.model = Mock()
|
||||
runner.message_hub = MessageHub.get_instance('test_after_train_epoch')
|
||||
# after_val_epoch should not save last_checkpoint
|
||||
self.assertFalse(
|
||||
osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))
|
||||
|
||||
def test_after_train_epoch(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train_loop._epoch = 9
|
||||
runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
|
||||
|
||||
# by epoch is True
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_epoch(runner)
|
||||
assert (runner.epoch + 1) % 2 == 0
|
||||
assert 'last_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('last_ckpt') == \
|
||||
osp.join(work_dir, 'epoch_10.pth')
|
||||
last_ckpt_path = osp.join(work_dir, 'last_checkpoint')
|
||||
assert osp.isfile(last_ckpt_path)
|
||||
self.assertEqual((runner.epoch + 1) % 2, 0)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('last_ckpt'),
|
||||
osp.join(cfg.work_dir, 'epoch_10.pth'))
|
||||
|
||||
last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint')
|
||||
self.assertTrue(osp.isfile(last_ckpt_path))
|
||||
|
||||
with open(last_ckpt_path) as f:
|
||||
filepath = f.read()
|
||||
assert filepath == osp.join(work_dir, 'epoch_10.pth')
|
||||
self.assertEqual(filepath, osp.join(cfg.work_dir, 'epoch_10.pth'))
|
||||
|
||||
# epoch can not be evenly divided by 2
|
||||
runner.epoch = 10
|
||||
runner.train_loop._epoch = 10
|
||||
checkpoint_hook.after_train_epoch(runner)
|
||||
assert 'last_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('last_ckpt') == \
|
||||
osp.join(work_dir, 'epoch_10.pth')
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('last_ckpt'),
|
||||
osp.join(cfg.work_dir, 'epoch_10.pth'))
|
||||
runner.message_hub.runtime_info.clear()
|
||||
|
||||
# by epoch is False
|
||||
runner.epoch = 9
|
||||
runner.message_hub = MessageHub.get_instance('test_after_train_epoch1')
|
||||
runner.train_loop._epoch = 9
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_epoch(runner)
|
||||
assert 'last_ckpt' not in runner.message_hub.runtime_info
|
||||
|
||||
# # max_keep_ckpts > 0
|
||||
runner.work_dir = work_dir
|
||||
os.system(f'touch {work_dir}/epoch_8.pth')
|
||||
checkpoint_hook = CheckpointHook(
|
||||
interval=2, by_epoch=True, max_keep_ckpts=1)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_epoch(runner)
|
||||
assert (runner.epoch + 1) % 2 == 0
|
||||
assert not os.path.exists(f'{work_dir}/epoch_8.pth')
|
||||
|
||||
# save_checkpoint of runner should be called with expected arguments
|
||||
runner = Mock()
|
||||
work_dir = str(tmp_path)
|
||||
runner.work_dir = tmp_path
|
||||
runner.epoch = 1
|
||||
runner.message_hub = MessageHub.get_instance('test_after_train_epoch2')
|
||||
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_epoch(runner)
|
||||
|
||||
runner.save_checkpoint.assert_called_once_with(
|
||||
runner.work_dir,
|
||||
'epoch_2.pth',
|
||||
None,
|
||||
backend_args=None,
|
||||
by_epoch=True,
|
||||
save_optimizer=True,
|
||||
save_param_scheduler=True)
|
||||
|
||||
def test_after_train_iter(self, tmp_path):
|
||||
work_dir = str(tmp_path)
|
||||
runner = Mock()
|
||||
runner.work_dir = str(work_dir)
|
||||
runner.iter = 9
|
||||
batch_idx = 9
|
||||
runner.model = Mock()
|
||||
runner.message_hub = MessageHub.get_instance('test_after_train_iter')
|
||||
self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
|
||||
runner.message_hub.runtime_info.clear()
|
||||
|
||||
def test_after_train_iter(self):
|
||||
# by epoch is True
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train_loop._iter = 9
|
||||
runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert 'last_ckpt' not in runner.message_hub.runtime_info
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=9)
|
||||
self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
|
||||
|
||||
# by epoch is False
|
||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert (runner.iter + 1) % 2 == 0
|
||||
assert 'last_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('last_ckpt') == \
|
||||
osp.join(work_dir, 'iter_10.pth')
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=9)
|
||||
self.assertIn('last_ckpt', runner.message_hub.runtime_info)
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('last_ckpt'),
|
||||
osp.join(cfg.work_dir, 'iter_10.pth'))
|
||||
|
||||
# epoch can not be evenly divided by 2
|
||||
runner.iter = 10
|
||||
runner.train_loop._iter = 10
|
||||
checkpoint_hook.after_train_epoch(runner)
|
||||
assert 'last_ckpt' in runner.message_hub.runtime_info and \
|
||||
runner.message_hub.get_info('last_ckpt') == \
|
||||
osp.join(work_dir, 'iter_10.pth')
|
||||
self.assertEqual(
|
||||
runner.message_hub.get_info('last_ckpt'),
|
||||
osp.join(cfg.work_dir, 'iter_10.pth'))
|
||||
|
||||
# max_keep_ckpts > 0
|
||||
runner.iter = 9
|
||||
runner.work_dir = work_dir
|
||||
os.system(f'touch {osp.join(work_dir, "iter_8.pth")}')
|
||||
checkpoint_hook = CheckpointHook(
|
||||
interval=2, by_epoch=False, max_keep_ckpts=1)
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert not os.path.exists(f'{work_dir}/iter_8.pth')
|
||||
|
||||
def test_with_runner(self, tmp_path):
|
||||
max_epoch = 10
|
||||
work_dir = osp.join(str(tmp_path), 'runner_test')
|
||||
tmpl = '{}.pth'
|
||||
save_interval = 2
|
||||
@parameterized.expand([['iter'], ['epoch']])
|
||||
def test_with_runner(self, training_type):
|
||||
# Test interval in epoch based training
|
||||
save_iterval = 2
|
||||
cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg'))
|
||||
setattr(cfg.train_cfg, f'max_{training_type}s', 11)
|
||||
checkpoint_cfg = dict(
|
||||
type='CheckpointHook',
|
||||
interval=save_interval,
|
||||
filename_tmpl=tmpl,
|
||||
by_epoch=True)
|
||||
runner = Runner(
|
||||
model=ToyModel(),
|
||||
work_dir=work_dir,
|
||||
train_dataloader=dict(
|
||||
dataset=DummyDataset(),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=DummyDataset(),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=dict(type=TriangleMetric, length=max_epoch),
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(
|
||||
by_epoch=True, max_epochs=max_epoch, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(checkpoint=checkpoint_cfg))
|
||||
interval=save_iterval,
|
||||
by_epoch=training_type == 'epoch')
|
||||
cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
for epoch in range(max_epoch):
|
||||
if epoch % save_interval != 0 or epoch == 0:
|
||||
continue
|
||||
path = osp.join(work_dir, tmpl.format(epoch))
|
||||
assert osp.isfile(path=path)
|
||||
|
||||
for i in range(1, 11):
|
||||
if i == 0:
|
||||
self.assertFalse(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
if i % 2 == 0:
|
||||
self.assertTrue(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
|
||||
self.assertTrue(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
|
||||
|
||||
# Test save_optimizer=False
|
||||
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
|
||||
self.assertIn('optimizer', ckpt)
|
||||
cfg.default_hooks.checkpoint.save_optimizer = False
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
|
||||
self.assertNotIn('optimizer', ckpt)
|
||||
|
||||
# Test save_param_scheduler=False
|
||||
cfg.param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.1,
|
||||
begin=0,
|
||||
end=500,
|
||||
by_epoch=training_type == 'epoch')
|
||||
]
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
|
||||
self.assertIn('param_schedulers', ckpt)
|
||||
|
||||
cfg.default_hooks.checkpoint.save_param_scheduler = False
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
|
||||
self.assertNotIn('param_schedulers', ckpt)
|
||||
|
||||
# Test out_dir
|
||||
out_dir = osp.join(self.temp_dir.name, 'out_dir')
|
||||
cfg.default_hooks.checkpoint.out_dir = out_dir
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
self.assertTrue(
|
||||
osp.isfile(
|
||||
osp.join(out_dir, osp.basename(cfg.work_dir),
|
||||
f'{training_type}_11.pth')))
|
||||
|
||||
# Test max_keep_ckpts.
|
||||
del cfg.default_hooks.checkpoint.out_dir
|
||||
cfg.default_hooks.checkpoint.max_keep_ckpts = 1
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
self.assertTrue(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
|
||||
|
||||
for i in range(10):
|
||||
self.assertFalse(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
|
||||
# Test filename_tmpl
|
||||
cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth')))
|
||||
|
||||
# Test save_best
|
||||
cfg.default_hooks.checkpoint.save_best = 'acc'
|
||||
cfg.val_evaluator = dict(type='TriangleMetric', length=11)
|
||||
cfg.train_cfg.val_interval = 1
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
self.assertTrue(
|
||||
osp.isfile(osp.join(cfg.work_dir, 'best_acc_test_5_pth')))
|
||||
|
|
Loading…
Reference in New Issue