mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support EarlyStoppingHook (#739)
* [Feature] EarlyStoppingHook * delete redundant line * Assert stop_training and rename tests * Fix UT * rename `metric` to `monitor` * Fix UT * Fix UT * edit docstring on patience * Draft for new code * fix ut * add test case * add test case * fix ut * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Append hook * Append hook * Apply suggestions * Update suggestions * Update mmengine/hooks/__init__.py * fix min_delta * Apply suggestions from code review * lint * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * delete save_last * infer rule more robust * refine unit test * Update mmengine/hooks/early_stopping_hook.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Co-authored-by: zhouzaida <zhouzaida@163.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
parent
d34fb58773
commit
b3430e4257
@ -25,3 +25,4 @@ mmengine.hooks
|
|||||||
ProfilerHook
|
ProfilerHook
|
||||||
NPUProfilerHook
|
NPUProfilerHook
|
||||||
PrepareTTAHook
|
PrepareTTAHook
|
||||||
|
EarlyStoppingHook
|
||||||
|
@ -25,3 +25,4 @@ mmengine.hooks
|
|||||||
ProfilerHook
|
ProfilerHook
|
||||||
NPUProfilerHook
|
NPUProfilerHook
|
||||||
PrepareTTAHook
|
PrepareTTAHook
|
||||||
|
EarlyStoppingHook
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .checkpoint_hook import CheckpointHook
|
from .checkpoint_hook import CheckpointHook
|
||||||
|
from .early_stopping_hook import EarlyStoppingHook
|
||||||
from .ema_hook import EMAHook
|
from .ema_hook import EMAHook
|
||||||
from .empty_cache_hook import EmptyCacheHook
|
from .empty_cache_hook import EmptyCacheHook
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
@ -17,5 +18,5 @@ __all__ = [
|
|||||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||||
'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
|
'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
|
||||||
'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook',
|
'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook',
|
||||||
'NPUProfilerHook', 'PrepareTTAHook'
|
'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook'
|
||||||
]
|
]
|
||||||
|
159
mmengine/hooks/early_stopping_hook.py
Normal file
159
mmengine/hooks/early_stopping_hook.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import warnings
|
||||||
|
from math import inf, isfinite
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class EarlyStoppingHook(Hook):
|
||||||
|
"""Early stop the training when the monitored metric reached a plateau.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
monitor (str): The monitored metric key to decide early stopping.
|
||||||
|
rule (str, optional): Comparison rule. Options are 'greater',
|
||||||
|
'less'. Defaults to None.
|
||||||
|
min_delta (float, optional): Minimum difference to continue the
|
||||||
|
training. Defaults to 0.01.
|
||||||
|
strict (bool, optional): Whether to crash the training when `monitor`
|
||||||
|
is not found in the `metrics`. Defaults to False.
|
||||||
|
check_finite: Whether to stop training when the monitor becomes NaN or
|
||||||
|
infinite. Defaults to True.
|
||||||
|
patience (int, optional): The times of validation with no improvement
|
||||||
|
after which training will be stopped. Defaults to 5.
|
||||||
|
stopping_threshold (float, optional): Stop training immediately once
|
||||||
|
the monitored quantity reaches this threshold. Defaults to None.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
`New in version 0.7.0.`
|
||||||
|
"""
|
||||||
|
priority = 'LOWEST'
|
||||||
|
|
||||||
|
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
|
||||||
|
_default_greater_keys = [
|
||||||
|
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
|
||||||
|
'mAcc', 'aAcc'
|
||||||
|
]
|
||||||
|
_default_less_keys = ['loss']
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
monitor: str,
|
||||||
|
rule: Optional[str] = None,
|
||||||
|
min_delta: float = 0.1,
|
||||||
|
strict: bool = False,
|
||||||
|
check_finite: bool = True,
|
||||||
|
patience: int = 5,
|
||||||
|
stopping_threshold: Optional[float] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.monitor = monitor
|
||||||
|
if rule is not None:
|
||||||
|
if rule not in ['greater', 'less']:
|
||||||
|
raise ValueError(
|
||||||
|
'`rule` should be either "greater" or "less", '
|
||||||
|
f'but got {rule}')
|
||||||
|
else:
|
||||||
|
rule = self._init_rule(monitor)
|
||||||
|
self.rule = rule
|
||||||
|
self.min_delta = min_delta if rule == 'greater' else -1 * min_delta
|
||||||
|
self.strict = strict
|
||||||
|
self.check_finite = check_finite
|
||||||
|
self.patience = patience
|
||||||
|
self.stopping_threshold = stopping_threshold
|
||||||
|
|
||||||
|
self.wait_count = 0
|
||||||
|
self.best_score = -inf if rule == 'greater' else inf
|
||||||
|
|
||||||
|
def _init_rule(self, monitor: str) -> str:
|
||||||
|
greater_keys = {key.lower() for key in self._default_greater_keys}
|
||||||
|
less_keys = {key.lower() for key in self._default_less_keys}
|
||||||
|
monitor_lc = monitor.lower()
|
||||||
|
if monitor_lc in greater_keys:
|
||||||
|
rule = 'greater'
|
||||||
|
elif monitor_lc in less_keys:
|
||||||
|
rule = 'less'
|
||||||
|
elif any(key in monitor_lc for key in greater_keys):
|
||||||
|
rule = 'greater'
|
||||||
|
elif any(key in monitor_lc for key in less_keys):
|
||||||
|
rule = 'less'
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Cannot infer the rule for {monitor}, thus rule '
|
||||||
|
'must be specified.')
|
||||||
|
return rule
|
||||||
|
|
||||||
|
def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]:
|
||||||
|
compare = self.rule_map[self.rule]
|
||||||
|
stop_training = False
|
||||||
|
reason_message = ''
|
||||||
|
|
||||||
|
if self.check_finite and not isfinite(current_score):
|
||||||
|
stop_training = True
|
||||||
|
reason_message = (f'Monitored metric {self.monitor} = '
|
||||||
|
f'{current_score} is infinite. '
|
||||||
|
f'Previous best value was '
|
||||||
|
f'{self.best_score:.3f}.')
|
||||||
|
|
||||||
|
elif self.stopping_threshold is not None and compare(
|
||||||
|
current_score, self.stopping_threshold):
|
||||||
|
stop_training = True
|
||||||
|
self.best_score = current_score
|
||||||
|
reason_message = (f'Stopping threshold reached: '
|
||||||
|
f'`{self.monitor}` = {current_score} is '
|
||||||
|
f'{self.rule} than {self.stopping_threshold}.')
|
||||||
|
elif compare(self.best_score + self.min_delta, current_score):
|
||||||
|
|
||||||
|
self.wait_count += 1
|
||||||
|
|
||||||
|
if self.wait_count >= self.patience:
|
||||||
|
reason_message = (f'the monitored metric did not improve '
|
||||||
|
f'in the last {self.wait_count} records. '
|
||||||
|
f'best score: {self.best_score:.3f}. ')
|
||||||
|
stop_training = True
|
||||||
|
else:
|
||||||
|
self.best_score = current_score
|
||||||
|
self.wait_count = 0
|
||||||
|
|
||||||
|
return stop_training, reason_message
|
||||||
|
|
||||||
|
def before_run(self, runner) -> None:
|
||||||
|
"""Check `stop_training` variable in `runner.train_loop`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert hasattr(runner.train_loop, 'stop_training'), \
|
||||||
|
'`train_loop` should contain `stop_training` variable.'
|
||||||
|
|
||||||
|
def after_val_epoch(self, runner, metrics):
|
||||||
|
"""Decide whether to stop the training process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
metrics (dict): Evaluation results of all metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.monitor not in metrics:
|
||||||
|
if self.strict:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Early stopping conditioned on metric '
|
||||||
|
f'`{self.monitor} is not available. Please check available'
|
||||||
|
f' metrics {metrics}, or set `strict=False` in '
|
||||||
|
'`EarlyStoppingHook`.')
|
||||||
|
warnings.warn(
|
||||||
|
'Skip early stopping process since the evaluation '
|
||||||
|
f'results ({metrics.keys()}) do not include `monitor` '
|
||||||
|
f'({self.monitor}).')
|
||||||
|
return
|
||||||
|
|
||||||
|
current_score = metrics[self.monitor]
|
||||||
|
|
||||||
|
stop_training, message = self._check_stop_condition(current_score)
|
||||||
|
if stop_training:
|
||||||
|
runner.train_loop.stop_training = True
|
||||||
|
runner.logger.info(message)
|
@ -49,6 +49,9 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
self._iter = 0
|
self._iter = 0
|
||||||
self.val_begin = val_begin
|
self.val_begin = val_begin
|
||||||
self.val_interval = val_interval
|
self.val_interval = val_interval
|
||||||
|
# This attribute will be updated by `EarlyStoppingHook`
|
||||||
|
# when it is enabled.
|
||||||
|
self.stop_training = False
|
||||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||||
self.runner.visualizer.dataset_meta = \
|
self.runner.visualizer.dataset_meta = \
|
||||||
self.dataloader.dataset.metainfo
|
self.dataloader.dataset.metainfo
|
||||||
@ -86,7 +89,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
"""Launch training."""
|
"""Launch training."""
|
||||||
self.runner.call_hook('before_train')
|
self.runner.call_hook('before_train')
|
||||||
|
|
||||||
while self._epoch < self._max_epochs:
|
while self._epoch < self._max_epochs and not self.stop_training:
|
||||||
self.run_epoch()
|
self.run_epoch()
|
||||||
|
|
||||||
self._decide_current_val_interval()
|
self._decide_current_val_interval()
|
||||||
@ -216,6 +219,9 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
self._iter = 0
|
self._iter = 0
|
||||||
self.val_begin = val_begin
|
self.val_begin = val_begin
|
||||||
self.val_interval = val_interval
|
self.val_interval = val_interval
|
||||||
|
# This attribute will be updated by `EarlyStoppingHook`
|
||||||
|
# when it is enabled.
|
||||||
|
self.stop_training = False
|
||||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||||
self.runner.visualizer.dataset_meta = \
|
self.runner.visualizer.dataset_meta = \
|
||||||
self.dataloader.dataset.metainfo
|
self.dataloader.dataset.metainfo
|
||||||
@ -257,7 +263,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
# In iteration-based training loop, we treat the whole training process
|
# In iteration-based training loop, we treat the whole training process
|
||||||
# as a big epoch and execute the corresponding hook.
|
# as a big epoch and execute the corresponding hook.
|
||||||
self.runner.call_hook('before_train_epoch')
|
self.runner.call_hook('before_train_epoch')
|
||||||
while self._iter < self._max_iters:
|
while self._iter < self._max_iters and not self.stop_training:
|
||||||
self.runner.model.train()
|
self.runner.model.train()
|
||||||
|
|
||||||
data_batch = next(self.dataloader_iterator)
|
data_batch = next(self.dataloader_iterator)
|
||||||
|
255
tests/test_hooks/test_early_stopping_hook.py
Normal file
255
tests/test_hooks/test_early_stopping_hook.py
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from mmengine.evaluator import BaseMetric
|
||||||
|
from mmengine.hooks import EarlyStoppingHook
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
|
from mmengine.model import BaseModel
|
||||||
|
from mmengine.optim import OptimWrapper
|
||||||
|
from mmengine.runner import Runner
|
||||||
|
from mmengine.testing import RunnerTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class ToyModel(BaseModel):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(2, 1)
|
||||||
|
|
||||||
|
def forward(self, inputs, data_sample, mode='tensor'):
|
||||||
|
labels = torch.stack(data_sample)
|
||||||
|
inputs = torch.stack(inputs)
|
||||||
|
outputs = self.linear(inputs)
|
||||||
|
if mode == 'tensor':
|
||||||
|
return outputs
|
||||||
|
elif mode == 'loss':
|
||||||
|
loss = (labels - outputs).sum()
|
||||||
|
outputs = dict(loss=loss)
|
||||||
|
return outputs
|
||||||
|
else:
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataset(Dataset):
|
||||||
|
METAINFO = dict() # type: ignore
|
||||||
|
data = torch.randn(12, 2)
|
||||||
|
label = torch.ones(12)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metainfo(self):
|
||||||
|
return self.METAINFO
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.data.size(0)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||||
|
|
||||||
|
|
||||||
|
class DummyMetric(BaseMetric):
|
||||||
|
|
||||||
|
default_prefix: str = 'test'
|
||||||
|
|
||||||
|
def __init__(self, length):
|
||||||
|
super().__init__()
|
||||||
|
self.length = length
|
||||||
|
self.best_idx = length // 2
|
||||||
|
self.cur_idx = 0
|
||||||
|
self.vals = [90, 91, 92, 88, 89, 90] * 2
|
||||||
|
|
||||||
|
def process(self, *args, **kwargs):
|
||||||
|
self.results.append(0)
|
||||||
|
|
||||||
|
def compute_metrics(self, *args, **kwargs):
|
||||||
|
acc = self.vals[self.cur_idx]
|
||||||
|
self.cur_idx += 1
|
||||||
|
return dict(acc=acc)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_runner():
|
||||||
|
runner = Mock()
|
||||||
|
runner.train_loop = Mock()
|
||||||
|
runner.train_loop.stop_training = False
|
||||||
|
return runner
|
||||||
|
|
||||||
|
|
||||||
|
class TestEarlyStoppingHook(RunnerTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||||
|
# delete the temporary directory
|
||||||
|
logging.shutdown()
|
||||||
|
MMLogger._instance_dict.clear()
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
|
||||||
|
hook = EarlyStoppingHook(monitor='acc')
|
||||||
|
self.assertEqual(hook.rule, 'greater')
|
||||||
|
self.assertLess(hook.best_score, 0)
|
||||||
|
|
||||||
|
hook = EarlyStoppingHook(monitor='ACC')
|
||||||
|
self.assertEqual(hook.rule, 'greater')
|
||||||
|
self.assertLess(hook.best_score, 0)
|
||||||
|
|
||||||
|
hook = EarlyStoppingHook(monitor='mAP_50')
|
||||||
|
self.assertEqual(hook.rule, 'greater')
|
||||||
|
self.assertLess(hook.best_score, 0)
|
||||||
|
|
||||||
|
hook = EarlyStoppingHook(monitor='loss')
|
||||||
|
self.assertEqual(hook.rule, 'less')
|
||||||
|
self.assertGreater(hook.best_score, 0)
|
||||||
|
|
||||||
|
hook = EarlyStoppingHook(monitor='Loss')
|
||||||
|
self.assertEqual(hook.rule, 'less')
|
||||||
|
self.assertGreater(hook.best_score, 0)
|
||||||
|
|
||||||
|
hook = EarlyStoppingHook(monitor='ce_loss')
|
||||||
|
self.assertEqual(hook.rule, 'less')
|
||||||
|
self.assertGreater(hook.best_score, 0)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# `rule` should be passed.
|
||||||
|
EarlyStoppingHook(monitor='recall')
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# Invalid `rule`
|
||||||
|
EarlyStoppingHook(monitor='accuracy/top1', rule='the world')
|
||||||
|
|
||||||
|
def test_before_run(self):
|
||||||
|
runner = Mock()
|
||||||
|
runner.train_loop = object()
|
||||||
|
|
||||||
|
# `train_loop` must contain `stop_training` variable.
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater')
|
||||||
|
hook.before_run(runner)
|
||||||
|
|
||||||
|
def test_after_val_epoch(self):
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = {'accuracy/top1': 0.5, 'loss': 0.23}
|
||||||
|
hook = EarlyStoppingHook(monitor='acc', rule='greater')
|
||||||
|
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
# Skip early stopping process since the evaluation results does not
|
||||||
|
# include the key 'acc'
|
||||||
|
hook.after_val_epoch(runner, metrics)
|
||||||
|
|
||||||
|
# if `monitor` does not match and strict=True, crash the training.
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
metrics = {'accuracy/top1': 0.5, 'loss': 0.23}
|
||||||
|
hook = EarlyStoppingHook(
|
||||||
|
monitor='acc', rule='greater', strict=True)
|
||||||
|
hook.after_val_epoch(runner, metrics)
|
||||||
|
|
||||||
|
# Check largest value
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = [{'accuracy/top1': i / 9.} for i in range(8)]
|
||||||
|
hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater')
|
||||||
|
for metric in metrics:
|
||||||
|
hook.after_val_epoch(runner, metric)
|
||||||
|
if runner.train_loop.stop_training:
|
||||||
|
break
|
||||||
|
self.assertAlmostEqual(hook.best_score, 7 / 9)
|
||||||
|
|
||||||
|
# Check smallest value
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = [{'loss': i / 9.} for i in range(8, 0, -1)]
|
||||||
|
hook = EarlyStoppingHook(monitor='loss')
|
||||||
|
for metric in metrics:
|
||||||
|
hook.after_val_epoch(runner, metric)
|
||||||
|
if runner.train_loop.stop_training:
|
||||||
|
break
|
||||||
|
self.assertAlmostEqual(hook.best_score, 1 / 9)
|
||||||
|
|
||||||
|
# Check stop training
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
|
||||||
|
hook = EarlyStoppingHook(
|
||||||
|
monitor='accuracy/top1', rule='greater', min_delta=1)
|
||||||
|
for metric in metrics:
|
||||||
|
hook.after_val_epoch(runner, metric)
|
||||||
|
if runner.train_loop.stop_training:
|
||||||
|
break
|
||||||
|
self.assertTrue(runner.train_loop.stop_training)
|
||||||
|
|
||||||
|
# Check finite
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = [{'accuracy/top1': math.inf} for i in range(5)]
|
||||||
|
hook = EarlyStoppingHook(
|
||||||
|
monitor='accuracy/top1', rule='greater', min_delta=1)
|
||||||
|
for metric in metrics:
|
||||||
|
hook.after_val_epoch(runner, metric)
|
||||||
|
if runner.train_loop.stop_training:
|
||||||
|
break
|
||||||
|
self.assertTrue(runner.train_loop.stop_training)
|
||||||
|
|
||||||
|
# Check patience
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
|
||||||
|
hook = EarlyStoppingHook(
|
||||||
|
monitor='accuracy/top1', rule='greater', min_delta=1, patience=10)
|
||||||
|
for metric in metrics:
|
||||||
|
hook.after_val_epoch(runner, metric)
|
||||||
|
if runner.train_loop.stop_training:
|
||||||
|
break
|
||||||
|
self.assertFalse(runner.train_loop.stop_training)
|
||||||
|
|
||||||
|
# Check stopping_threshold
|
||||||
|
runner = get_mock_runner()
|
||||||
|
metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
|
||||||
|
hook = EarlyStoppingHook(
|
||||||
|
monitor='accuracy/top1',
|
||||||
|
rule='greater',
|
||||||
|
stopping_threshold=98.5,
|
||||||
|
patience=0)
|
||||||
|
for metric in metrics:
|
||||||
|
hook.after_val_epoch(runner, metric)
|
||||||
|
if runner.train_loop.stop_training:
|
||||||
|
break
|
||||||
|
self.assertAlmostEqual(hook.best_score.item(), 98 + 4 / 7, places=5)
|
||||||
|
|
||||||
|
def test_with_runner(self):
|
||||||
|
max_epoch = 10
|
||||||
|
work_dir = osp.join(self.temp_dir.name, 'runner_test')
|
||||||
|
early_stop_cfg = dict(
|
||||||
|
type='EarlyStoppingHook',
|
||||||
|
monitor='test/acc',
|
||||||
|
rule='greater',
|
||||||
|
min_delta=1,
|
||||||
|
patience=3,
|
||||||
|
)
|
||||||
|
runner = Runner(
|
||||||
|
model=ToyModel(),
|
||||||
|
work_dir=work_dir,
|
||||||
|
train_dataloader=dict(
|
||||||
|
dataset=DummyDataset(),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
val_dataloader=dict(
|
||||||
|
dataset=DummyDataset(),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
val_evaluator=dict(type=DummyMetric, length=max_epoch),
|
||||||
|
optim_wrapper=OptimWrapper(
|
||||||
|
torch.optim.Adam(ToyModel().parameters())),
|
||||||
|
train_cfg=dict(
|
||||||
|
by_epoch=True, max_epochs=max_epoch, val_interval=1),
|
||||||
|
val_cfg=dict(),
|
||||||
|
custom_hooks=[early_stop_cfg],
|
||||||
|
experiment_name='earlystop_test')
|
||||||
|
runner.train()
|
||||||
|
self.assertEqual(runner.epoch, 6)
|
Loading…
x
Reference in New Issue
Block a user