mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add empty cache hook (#58)
* [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * [Fix]: Add docstring and type hint for base hook * [Fix]: Add test case to not the last iter, inner_iter, epoch * [Fix]: Add missing type hint * [Feature]: Add Args and Returns in docstring * [Fix]: Add missing colon * [Fix]: Add optional to docstring * [Fix]: Fix docstring problem * [Fix]: Fix lint * fix lint * update typing and docs * fix lint * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_hook/test_empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix comments * remove test condition Co-authored-by: seuyou <3463423099@qq.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
63a3af4f8c
commit
94ab45d07e
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .empty_cache_hook import EmptyCacheHook
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
from .iter_timer_hook import IterTimerHook
|
from .iter_timer_hook import IterTimerHook
|
||||||
from .optimizer_hook import OptimizerHook
|
from .optimizer_hook import OptimizerHook
|
||||||
@ -7,5 +8,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||||
'OptimizerHook'
|
'OptimizerHook', 'EmptyCacheHook'
|
||||||
]
|
]
|
||||||
|
65
mmengine/hooks/empty_cache_hook.py
Normal file
65
mmengine/hooks/empty_cache_hook.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class EmptyCacheHook(Hook):
|
||||||
|
"""Releases all unoccupied cached GPU memory during the process of
|
||||||
|
training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
before_epoch (bool): Whether to release cache before an epoch. Defaults
|
||||||
|
to False.
|
||||||
|
after_epoch (bool): Whether to release cache after an epoch. Defaults
|
||||||
|
to True.
|
||||||
|
after_iter (bool): Whether to release cache after an iteration.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
before_epoch: bool = False,
|
||||||
|
after_epoch: bool = True,
|
||||||
|
after_iter: bool = False) -> None:
|
||||||
|
self._before_epoch = before_epoch
|
||||||
|
self._after_epoch = after_epoch
|
||||||
|
self._after_iter = after_iter
|
||||||
|
|
||||||
|
def after_iter(self,
|
||||||
|
runner: object,
|
||||||
|
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||||
|
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||||
|
"""Empty cache after an iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||||
|
Defaults to None.
|
||||||
|
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
if self._after_iter:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def before_epoch(self, runner: object) -> None:
|
||||||
|
"""Empty cache before an epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
if self._before_epoch:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def after_epoch(self, runner: object) -> None:
|
||||||
|
"""Empty cache after an epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
if self._after_epoch:
|
||||||
|
torch.cuda.empty_cache()
|
14
tests/test_hook/test_empty_cache_hook.py
Normal file
14
tests/test_hook/test_empty_cache_hook.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
from mmengine.hooks import EmptyCacheHook
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmptyCacheHook:
|
||||||
|
|
||||||
|
def test_emtpy_cache_hook(self):
|
||||||
|
Hook = EmptyCacheHook(True, True, True)
|
||||||
|
Runner = Mock()
|
||||||
|
Hook.after_iter(Runner)
|
||||||
|
Hook.before_epoch(Runner)
|
||||||
|
Hook.after_epoch(Runner)
|
Loading…
x
Reference in New Issue
Block a user