diff --git a/docs/en/api/hooks.rst b/docs/en/api/hooks.rst
index 381f196f..3e996d62 100644
--- a/docs/en/api/hooks.rst
+++ b/docs/en/api/hooks.rst
@@ -25,3 +25,4 @@ mmengine.hooks
    ProfilerHook
    NPUProfilerHook
    PrepareTTAHook
+   EarlyStoppingHook
diff --git a/docs/zh_cn/api/hooks.rst b/docs/zh_cn/api/hooks.rst
index 381f196f..3e996d62 100644
--- a/docs/zh_cn/api/hooks.rst
+++ b/docs/zh_cn/api/hooks.rst
@@ -25,3 +25,4 @@ mmengine.hooks
    ProfilerHook
    NPUProfilerHook
    PrepareTTAHook
+   EarlyStoppingHook
diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 1f892993..746be6b0 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .checkpoint_hook import CheckpointHook
+from .early_stopping_hook import EarlyStoppingHook
 from .ema_hook import EMAHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
@@ -17,5 +18,5 @@ __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
     'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook',
-    'NPUProfilerHook', 'PrepareTTAHook'
+    'PrepareTTAHook', 'NPUProfilerHook', 'EarlyStoppingHook'
 ]
diff --git a/mmengine/hooks/early_stopping_hook.py b/mmengine/hooks/early_stopping_hook.py
new file mode 100644
index 00000000..e047ccbb
--- /dev/null
+++ b/mmengine/hooks/early_stopping_hook.py
@@ -0,0 +1,159 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from math import inf, isfinite
+from typing import Optional, Tuple, Union
+
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+DATA_BATCH = Optional[Union[dict, tuple, list]]
+
+
+@HOOKS.register_module()
+class EarlyStoppingHook(Hook):
+    """Early stop the training when the monitored metric reaches a plateau.
+
+    Args:
+        monitor (str): The monitored metric key to decide early stopping.
+        rule (str, optional): Comparison rule. Options are 'greater',
+            'less'. Defaults to None.
+        min_delta (float, optional): Minimum difference to continue the
+            training. Defaults to 0.1.
+        strict (bool, optional): Whether to crash the training when `monitor`
+            is not found in the `metrics`. Defaults to False.
+        check_finite (bool, optional): Whether to stop training when the
+            monitor becomes NaN or infinite. Defaults to True.
+        patience (int, optional): The times of validation with no improvement
+            after which training will be stopped. Defaults to 5.
+        stopping_threshold (float, optional): Stop training immediately once
+            the monitored quantity reaches this threshold. Defaults to None.
+
+    Note:
+        `New in version 0.7.0.`
+    """
+    priority = 'LOWEST'
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
+    def __init__(
+        self,
+        monitor: str,
+        rule: Optional[str] = None,
+        min_delta: float = 0.1,
+        strict: bool = False,
+        check_finite: bool = True,
+        patience: int = 5,
+        stopping_threshold: Optional[float] = None,
+    ):
+
+        self.monitor = monitor
+        if rule is not None:
+            if rule not in ['greater', 'less']:
+                raise ValueError(
+                    '`rule` should be either "greater" or "less", '
+                    f'but got {rule}')
+        else:
+            rule = self._init_rule(monitor)
+        self.rule = rule
+        self.min_delta = min_delta if rule == 'greater' else -1 * min_delta
+        self.strict = strict
+        self.check_finite = check_finite
+        self.patience = patience
+        self.stopping_threshold = stopping_threshold
+
+        self.wait_count = 0
+        self.best_score = -inf if rule == 'greater' else inf
+
+    def _init_rule(self, monitor: str) -> str:
+        greater_keys = {key.lower() for key in self._default_greater_keys}
+        less_keys = {key.lower() for key in self._default_less_keys}
+        monitor_lc = monitor.lower()
+        if monitor_lc in greater_keys:
+            rule = 'greater'
+        elif monitor_lc in less_keys:
+            rule = 'less'
+        elif any(key in monitor_lc for key in greater_keys):
+            rule = 'greater'
+        elif any(key in monitor_lc for key in less_keys):
+            rule = 'less'
+        else:
+            raise ValueError(f'Cannot infer the rule for {monitor}, thus rule '
+                             'must be specified.')
+        return rule
+
+    def _check_stop_condition(self, current_score: float) -> Tuple[bool, str]:
+        compare = self.rule_map[self.rule]
+        stop_training = False
+        reason_message = ''
+
+        if self.check_finite and not isfinite(current_score):
+            stop_training = True
+            reason_message = (f'Monitored metric {self.monitor} = '
+                              f'{current_score} is not finite. '
+                              f'Previous best value was '
+                              f'{self.best_score:.3f}.')
+
+        elif self.stopping_threshold is not None and compare(
+                current_score, self.stopping_threshold):
+            stop_training = True
+            self.best_score = current_score
+            reason_message = (f'Stopping threshold reached: '
+                              f'`{self.monitor}` = {current_score} is '
+                              f'{self.rule} than {self.stopping_threshold}.')
+        elif compare(self.best_score + self.min_delta, current_score):
+            self.wait_count += 1
+
+            if self.wait_count >= self.patience:
+                reason_message = (f'The monitored metric did not improve '
+                                  f'in the last {self.wait_count} records. '
+                                  f'Best score: {self.best_score:.3f}.')
+                stop_training = True
+        else:
+            self.best_score = current_score
+            self.wait_count = 0
+
+        return stop_training, reason_message
+
+    def before_run(self, runner) -> None:
+        """Check `stop_training` variable in `runner.train_loop`.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+
+        assert hasattr(runner.train_loop, 'stop_training'), \
+            '`train_loop` should contain `stop_training` variable.'
+
+    def after_val_epoch(self, runner, metrics):
+        """Decide whether to stop the training process.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            metrics (dict): Evaluation results of all metrics.
+        """
+
+        if self.monitor not in metrics:
+            if self.strict:
+                raise RuntimeError(
+                    'Early stopping conditioned on metric '
+                    f'`{self.monitor}` is not available. Please check available'
+                    f' metrics {metrics}, or set `strict=False` in '
+                    '`EarlyStoppingHook`.')
+            warnings.warn(
+                'Skip early stopping process since the evaluation '
+                f'results ({metrics.keys()}) do not include `monitor` '
+                f'({self.monitor}).')
+            return
+
+        current_score = metrics[self.monitor]
+
+        stop_training, message = self._check_stop_condition(current_score)
+        if stop_training:
+            runner.train_loop.stop_training = True
+            runner.logger.info(message)
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 67b8d5b6..a274fa6d 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -49,6 +49,9 @@ class EpochBasedTrainLoop(BaseLoop):
         self._iter = 0
         self.val_begin = val_begin
         self.val_interval = val_interval
+        # This attribute will be updated by `EarlyStoppingHook`
+        # when it is enabled.
+        self.stop_training = False
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.runner.visualizer.dataset_meta = \
                 self.dataloader.dataset.metainfo
@@ -86,7 +89,7 @@
         """Launch training."""
         self.runner.call_hook('before_train')
 
-        while self._epoch < self._max_epochs:
+        while self._epoch < self._max_epochs and not self.stop_training:
             self.run_epoch()
 
             self._decide_current_val_interval()
@@ -216,6 +219,9 @@ class IterBasedTrainLoop(BaseLoop):
         self._iter = 0
         self.val_begin = val_begin
         self.val_interval = val_interval
+        # This attribute will be updated by `EarlyStoppingHook`
+        # when it is enabled.
+        self.stop_training = False
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.runner.visualizer.dataset_meta = \
                 self.dataloader.dataset.metainfo
@@ -257,7 +263,7 @@
         # In iteration-based training loop, we treat the whole training
         # process as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
-        while self._iter < self._max_iters:
+        while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
             data_batch = next(self.dataloader_iterator)
diff --git a/tests/test_hooks/test_early_stopping_hook.py b/tests/test_hooks/test_early_stopping_hook.py
new file mode 100644
index 00000000..16f8fd98
--- /dev/null
+++ b/tests/test_hooks/test_early_stopping_hook.py
@@ -0,0 +1,255 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import math
+import os.path as osp
+import tempfile
+from unittest.mock import Mock
+
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+
+from mmengine.evaluator import BaseMetric
+from mmengine.hooks import EarlyStoppingHook
+from mmengine.logging import MMLogger
+from mmengine.model import BaseModel
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+from mmengine.testing import RunnerTestCase
+
+
+class ToyModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 1)
+
+    def forward(self, inputs, data_sample, mode='tensor'):
+        labels = torch.stack(data_sample)
+        inputs = torch.stack(inputs)
+        outputs = self.linear(inputs)
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        else:
+            return outputs
+
+
+class DummyDataset(Dataset):
+    METAINFO = dict()  # type: ignore
+    data = torch.randn(12, 2)
+    label = torch.ones(12)
+
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
+    def __len__(self):
+        return self.data.size(0)
+
+    def __getitem__(self, index):
+        return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
+class DummyMetric(BaseMetric):
+
+    default_prefix: str = 'test'
+
+    def __init__(self, length):
+        super().__init__()
+        self.length = length
+        self.best_idx = length // 2
+        self.cur_idx = 0
+        self.vals = [90, 91, 92, 88, 89, 90] * 2
+
+    def process(self, *args, **kwargs):
+        self.results.append(0)
+
+    def compute_metrics(self, *args, **kwargs):
+        acc = self.vals[self.cur_idx]
+        self.cur_idx += 1
+        return dict(acc=acc)
+
+
+def get_mock_runner():
+    runner = Mock()
+    runner.train_loop = Mock()
+    runner.train_loop.stop_training = False
+    return runner
+
+
+class TestEarlyStoppingHook(RunnerTestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        # `FileHandler` should be closed on Windows, otherwise we cannot
+        # delete the temporary directory
+        logging.shutdown()
+        MMLogger._instance_dict.clear()
+        self.temp_dir.cleanup()
+
+    def test_init(self):
+
+        hook = EarlyStoppingHook(monitor='acc')
+        self.assertEqual(hook.rule, 'greater')
+        self.assertLess(hook.best_score, 0)
+
+        hook = EarlyStoppingHook(monitor='ACC')
+        self.assertEqual(hook.rule, 'greater')
+        self.assertLess(hook.best_score, 0)
+
+        hook = EarlyStoppingHook(monitor='mAP_50')
+        self.assertEqual(hook.rule, 'greater')
+        self.assertLess(hook.best_score, 0)
+
+        hook = EarlyStoppingHook(monitor='loss')
+        self.assertEqual(hook.rule, 'less')
+        self.assertGreater(hook.best_score, 0)
+
+        hook = EarlyStoppingHook(monitor='Loss')
+        self.assertEqual(hook.rule, 'less')
+        self.assertGreater(hook.best_score, 0)
+
+        hook = EarlyStoppingHook(monitor='ce_loss')
+        self.assertEqual(hook.rule, 'less')
+        self.assertGreater(hook.best_score, 0)
+
+        with self.assertRaises(ValueError):
+            # `rule` should be passed.
+            EarlyStoppingHook(monitor='recall')
+
+        with self.assertRaises(ValueError):
+            # Invalid `rule`
+            EarlyStoppingHook(monitor='accuracy/top1', rule='the world')
+
+    def test_before_run(self):
+        runner = Mock()
+        runner.train_loop = object()
+
+        # `train_loop` must contain `stop_training` variable.
+        with self.assertRaises(AssertionError):
+            hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater')
+            hook.before_run(runner)
+
+    def test_after_val_epoch(self):
+        runner = get_mock_runner()
+        metrics = {'accuracy/top1': 0.5, 'loss': 0.23}
+        hook = EarlyStoppingHook(monitor='acc', rule='greater')
+
+        with self.assertWarns(UserWarning):
+            # Skip the early stopping process since the evaluation results
+            # do not include the key 'acc'
+            hook.after_val_epoch(runner, metrics)
+
+        # If `monitor` does not match and strict=True, crash the training.
+        with self.assertRaises(RuntimeError):
+            metrics = {'accuracy/top1': 0.5, 'loss': 0.23}
+            hook = EarlyStoppingHook(
+                monitor='acc', rule='greater', strict=True)
+            hook.after_val_epoch(runner, metrics)
+
+        # Check largest value
+        runner = get_mock_runner()
+        metrics = [{'accuracy/top1': i / 9.} for i in range(8)]
+        hook = EarlyStoppingHook(monitor='accuracy/top1', rule='greater')
+        for metric in metrics:
+            hook.after_val_epoch(runner, metric)
+            if runner.train_loop.stop_training:
+                break
+        self.assertAlmostEqual(hook.best_score, 7 / 9)
+
+        # Check smallest value
+        runner = get_mock_runner()
+        metrics = [{'loss': i / 9.} for i in range(8, 0, -1)]
+        hook = EarlyStoppingHook(monitor='loss')
+        for metric in metrics:
+            hook.after_val_epoch(runner, metric)
+            if runner.train_loop.stop_training:
+                break
+        self.assertAlmostEqual(hook.best_score, 1 / 9)
+
+        # Check stop training
+        runner = get_mock_runner()
+        metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
+        hook = EarlyStoppingHook(
+            monitor='accuracy/top1', rule='greater', min_delta=1)
+        for metric in metrics:
+            hook.after_val_epoch(runner, metric)
+            if runner.train_loop.stop_training:
+                break
+        self.assertTrue(runner.train_loop.stop_training)
+
+        # Check finite
+        runner = get_mock_runner()
+        metrics = [{'accuracy/top1': math.inf} for _ in range(5)]
+        hook = EarlyStoppingHook(
+            monitor='accuracy/top1', rule='greater', min_delta=1)
+        for metric in metrics:
+            hook.after_val_epoch(runner, metric)
+            if runner.train_loop.stop_training:
+                break
+        self.assertTrue(runner.train_loop.stop_training)
+
+        # Check patience
+        runner = get_mock_runner()
+        metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
+        hook = EarlyStoppingHook(
+            monitor='accuracy/top1', rule='greater', min_delta=1, patience=10)
+        for metric in metrics:
+            hook.after_val_epoch(runner, metric)
+            if runner.train_loop.stop_training:
+                break
+        self.assertFalse(runner.train_loop.stop_training)
+
+        # Check stopping_threshold
+        runner = get_mock_runner()
+        metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)]
+        hook = EarlyStoppingHook(
+            monitor='accuracy/top1',
+            rule='greater',
+            stopping_threshold=98.5,
+            patience=0)
+        for metric in metrics:
+            hook.after_val_epoch(runner, metric)
+            if runner.train_loop.stop_training:
+                break
+        self.assertAlmostEqual(hook.best_score.item(), 98 + 4 / 7, places=5)
+
+    def test_with_runner(self):
+        max_epoch = 10
+        work_dir = osp.join(self.temp_dir.name, 'runner_test')
+        early_stop_cfg = dict(
+            type='EarlyStoppingHook',
+            monitor='test/acc',
+            rule='greater',
+            min_delta=1,
+            patience=3,
+        )
+        model = ToyModel()
+        runner = Runner(
+            model=model,
+            work_dir=work_dir,
+            train_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=dict(type=DummyMetric, length=max_epoch),
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(model.parameters())),
+            train_cfg=dict(
+                by_epoch=True, max_epochs=max_epoch, val_interval=1),
+            val_cfg=dict(),
+            custom_hooks=[early_stop_cfg],
+            experiment_name='earlystop_test')
+        runner.train()
+        self.assertEqual(runner.epoch, 6)
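
Usage note for reviewers: a minimal sketch of how the new hook would be enabled from a runner config, mirroring `test_with_runner` above. The monitored key `accuracy/top1` and the numeric values below are illustrative assumptions, not values taken from this patch; `monitor` must match a key actually reported by the validation evaluator.

    custom_hooks = [
        dict(
            type='EarlyStoppingHook',
            monitor='accuracy/top1',  # illustrative; use a key your evaluator reports
            rule='greater',   # stop when the metric stops increasing
            min_delta=0.005,  # smallest change counted as an improvement
            patience=5,       # validation rounds without improvement before stopping
        )
    ]

If `rule` is omitted, the hook infers it from the metric name via `_default_greater_keys`/`_default_less_keys`, and raises a `ValueError` for names (e.g. `recall`) that match neither list.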