# Mirror of https://github.com/open-mmlab/mmengine.git
# (synced 2025-06-03 21:54:44 +08:00; 262 lines, 10 KiB, Python)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from unittest.mock import Mock, patch
import pytest
from mmengine.hooks import CheckpointHook
from mmengine.logging import MessageHub


class MockPetrel:
    """Minimal fake storage backend exposing only ``name`` and
    ``allow_symlink``.

    It is the ``'s3'`` entry of ``prefix_to_backends``, which the tests
    patch over ``FileClient._prefix_to_backends``.
    """

    # Class-level default read by ``allow_symlink``; this fake never
    # allows symlinks unless a subclass overrides it.
    _allow_symlink = False

    def __init__(self):
        """Take no arguments; the fake backend holds no state."""

    @property
    def name(self):
        """str: Backend name, i.e. the concrete class name."""
        return type(self).__name__

    @property
    def allow_symlink(self):
        """bool: Whether symlinks are allowed (the ``_allow_symlink`` value)."""
        return self._allow_symlink


# Prefix -> backend-class registry; test_before_train patches it over
# FileClient._prefix_to_backends so the 's3' prefix maps to MockPetrel.
prefix_to_backends = dict(s3=MockPetrel)


class TestCheckpointHook:
    """Tests for :class:`CheckpointHook` driven by a ``Mock`` runner.

    Each test sets ``work_dir``, ``epoch``/``iter`` and ``message_hub`` on the
    mock runner, calls one hook stage, and then checks the ``runtime_info``
    recorded in the message hub (``best_score``, ``best_ckpt``,
    ``last_ckpt``) or which checkpoint files remain on disk.
    """

    @patch('mmengine.fileio.file_client.FileClient._prefix_to_backends',
           prefix_to_backends)
    def test_before_train(self, tmp_path):
        """``before_train`` resolves ``out_dir`` from the runner's work dir."""
        runner = Mock()
        work_dir = str(tmp_path)
        runner.work_dir = work_dir

        # Without out_dir, out_dir falls back to the runner's work_dir.
        checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
        checkpoint_hook.before_train(runner)
        assert checkpoint_hook.out_dir == runner.work_dir

        # With out_dir, the basename of work_dir is appended to it.
        checkpoint_hook = CheckpointHook(
            interval=1, by_epoch=True, out_dir='test_dir')
        checkpoint_hook.before_train(runner)
        assert checkpoint_hook.out_dir == (
            f'test_dir/{osp.basename(work_dir)}')

    def test_after_val_epoch(self, tmp_path):
        """``after_val_epoch`` tracks the best score/checkpoint per rule."""
        runner = Mock()
        runner.work_dir = tmp_path
        runner.epoch = 9
        runner.model = Mock()
        runner.message_hub = MessageHub.get_instance('test_after_val_epoch')

        with pytest.raises(ValueError):
            # key_indicator must be valid when rule_map is None
            CheckpointHook(interval=2, by_epoch=True, save_best='unsupport')

        with pytest.raises(KeyError):
            # rule must be in keys of rule_map
            CheckpointHook(
                interval=2, by_epoch=True, save_best='auto', rule='unsupport')

        # If eval_res is empty (None here), a warning is emitted.
        with pytest.warns(UserWarning) as record_warnings:
            eval_hook = CheckpointHook(
                interval=2, by_epoch=True, save_best='auto')
            eval_hook._get_metric_score(None)
        # Several warnings may be raised; only require that ours is among them.
        expected_message = (
            'Since `eval_res` is an empty dict, the behavior to '
            'save the best checkpoint will be skipped in this '
            'evaluation.')
        assert any(
            str(warning.message) == expected_message
            for warning in record_warnings)

        # If save_best is None, no best_ckpt meta should be stored.
        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
        eval_hook.before_train(runner)
        eval_hook.after_val_epoch(runner, None)
        assert 'best_score' not in runner.message_hub.runtime_info
        assert 'best_ckpt' not in runner.message_hub.runtime_info

        # When save_best is 'auto', the first metric is used.
        metrics = {'acc': 0.5, 'map': 0.3}
        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
        eval_hook.before_train(runner)
        eval_hook.after_val_epoch(runner, metrics)
        best_ckpt_name = 'best_acc_epoch_10.pth'
        best_ckpt_path = eval_hook.file_client.join_path(
            eval_hook.out_dir, best_ckpt_name)
        assert eval_hook.key_indicator == 'acc'
        assert eval_hook.rule == 'greater'
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 0.5
        assert 'best_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_ckpt') == best_ckpt_path

        # When save_best is 'acc', a greater value should replace the best.
        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
        eval_hook.before_train(runner)
        metrics['acc'] = 0.8
        eval_hook.after_val_epoch(runner, metrics)
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 0.8

        # When save_best is 'loss', a smaller value should replace the best.
        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
        eval_hook.before_train(runner)
        metrics['loss'] = 0.8
        eval_hook.after_val_epoch(runner, metrics)
        metrics['loss'] = 0.5
        eval_hook.after_val_epoch(runner, metrics)
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 0.5

        # When rule is 'less', a smaller value replaces the best,
        # regardless of which metric save_best names.
        eval_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='acc', rule='less')
        eval_hook.before_train(runner)
        metrics['acc'] = 0.3
        eval_hook.after_val_epoch(runner, metrics)
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 0.3

        # When rule is 'greater', a greater value replaces the best,
        # regardless of which metric save_best names.
        eval_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='loss', rule='greater')
        eval_hook.before_train(runner)
        metrics['loss'] = 1.0
        eval_hook.after_val_epoch(runner, metrics)
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 1.0

        # Behavior when by_epoch is False: names use the iteration count.
        runner = Mock()
        runner.work_dir = tmp_path
        runner.iter = 9
        runner.model = Mock()
        runner.message_hub = MessageHub.get_instance(
            'test_after_val_epoch_by_epoch_is_false')

        # Check the best checkpoint name and best score.
        metrics = {'acc': 0.5, 'map': 0.3}
        eval_hook = CheckpointHook(
            interval=2, by_epoch=False, save_best='acc', rule='greater')
        eval_hook.before_train(runner)
        eval_hook.after_val_epoch(runner, metrics)
        assert eval_hook.key_indicator == 'acc'
        assert eval_hook.rule == 'greater'
        best_ckpt_name = 'best_acc_iter_10.pth'
        best_ckpt_path = eval_hook.file_client.join_path(
            eval_hook.out_dir, best_ckpt_name)
        assert 'best_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 0.5

        # Check that the best score is updated. runner.iter is unchanged,
        # so the expected best checkpoint path stays the same.
        metrics['acc'] = 0.666
        eval_hook.after_val_epoch(runner, metrics)
        assert 'best_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
        assert 'best_score' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score') == 0.666

    def test_after_train_epoch(self, tmp_path):
        """``after_train_epoch`` saves every ``interval`` epochs if by_epoch."""
        runner = Mock()
        work_dir = str(tmp_path)
        runner.work_dir = tmp_path
        runner.epoch = 9
        runner.model = Mock()
        runner.message_hub = MessageHub.get_instance('test_after_train_epoch')

        # by_epoch is True: epoch index 9 is the 10th epoch, a multiple of 2.
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_epoch(runner)
        assert (runner.epoch + 1) % 2 == 0
        assert 'last_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('last_ckpt') == (
                f'{work_dir}/epoch_10.pth')

        # The 11th epoch is not a multiple of 2, so last_ckpt is unchanged.
        runner.epoch = 10
        checkpoint_hook.after_train_epoch(runner)
        assert 'last_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('last_ckpt') == (
                f'{work_dir}/epoch_10.pth')

        # by_epoch is False: after_train_epoch records no checkpoint.
        runner.epoch = 9
        runner.message_hub = MessageHub.get_instance('test_after_train_epoch1')
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_epoch(runner)
        assert 'last_ckpt' not in runner.message_hub.runtime_info

        # max_keep_ckpts > 0: the stale epoch_8 checkpoint must be removed.
        # Create it via pathlib instead of shelling out to `touch`, which is
        # POSIX-only and ignores failures; then confirm the precondition so
        # the removal check below cannot pass vacuously.
        runner.work_dir = work_dir
        (tmp_path / 'epoch_8.pth').touch()
        assert os.path.exists(f'{work_dir}/epoch_8.pth')
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, max_keep_ckpts=1)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_epoch(runner)
        assert (runner.epoch + 1) % 2 == 0
        assert not os.path.exists(f'{work_dir}/epoch_8.pth')

    def test_after_train_iter(self, tmp_path):
        """``after_train_iter`` saves every ``interval`` iters if not by_epoch."""
        work_dir = str(tmp_path)
        runner = Mock()
        runner.work_dir = work_dir
        runner.iter = 9
        batch_idx = 9
        runner.model = Mock()
        runner.message_hub = MessageHub.get_instance('test_after_train_iter')

        # by_epoch is True: after_train_iter records no checkpoint.
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
        assert 'last_ckpt' not in runner.message_hub.runtime_info

        # by_epoch is False: iter index 9 is the 10th iter, a multiple of 2.
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
        assert (runner.iter + 1) % 2 == 0
        assert 'last_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('last_ckpt') == (
                f'{work_dir}/iter_10.pth')

        # The 11th iter is not a multiple of 2, so last_ckpt is unchanged.
        # This must go through after_train_iter: after_train_epoch is a
        # no-op for a by_epoch=False hook and would not test this path.
        runner.iter = 10
        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
        assert 'last_ckpt' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('last_ckpt') == (
                f'{work_dir}/iter_10.pth')

        # max_keep_ckpts > 0: the stale iter_8 checkpoint must be removed.
        # Create it via pathlib instead of shelling out to `touch`, and
        # confirm the precondition so the removal check is meaningful.
        runner.iter = 9
        runner.work_dir = work_dir
        (tmp_path / 'iter_8.pth').touch()
        assert os.path.exists(f'{work_dir}/iter_8.pth')
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=False, max_keep_ckpts=1)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
        assert not os.path.exists(f'{work_dir}/iter_8.pth')