mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add empty cache hook (#58)
* [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * [Fix]: Add docstring and type hint for base hook * [Fix]: Add test case to not the last iter, inner_iter, epoch * [Fix]: Add missing type hint * [Feature]: Add Args and Returns in docstring * [Fix]: Add missing colon * [Fix]: Add optional to docstring * [Fix]: Fix docstring problem * [Fix]: Fix lint * fix lint * update typing and docs * fix lint * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/hooks/empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_hook/test_empty_cache_hook.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix comments * remove test condition Co-authored-by: seuyou <3463423099@qq.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
63a3af4f8c
commit
94ab45d07e
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .empty_cache_hook import EmptyCacheHook
|
||||
from .hook import Hook
|
||||
from .iter_timer_hook import IterTimerHook
|
||||
from .optimizer_hook import OptimizerHook
|
||||
@ -7,5 +8,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
|
||||
|
||||
__all__ = [
|
||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||
'OptimizerHook'
|
||||
'OptimizerHook', 'EmptyCacheHook'
|
||||
]
|
||||
|
65
mmengine/hooks/empty_cache_hook.py
Normal file
65
mmengine/hooks/empty_cache_hook.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class EmptyCacheHook(Hook):
|
||||
"""Releases all unoccupied cached GPU memory during the process of
|
||||
training.
|
||||
|
||||
Args:
|
||||
before_epoch (bool): Whether to release cache before an epoch. Defaults
|
||||
to False.
|
||||
after_epoch (bool): Whether to release cache after an epoch. Defaults
|
||||
to True.
|
||||
after_iter (bool): Whether to release cache after an iteration.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
before_epoch: bool = False,
|
||||
after_epoch: bool = True,
|
||||
after_iter: bool = False) -> None:
|
||||
self._before_epoch = before_epoch
|
||||
self._after_epoch = after_epoch
|
||||
self._after_iter = after_iter
|
||||
|
||||
def after_iter(self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Empty cache after an iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
if self._after_iter:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def before_epoch(self, runner: object) -> None:
|
||||
"""Empty cache before an epoch.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
"""
|
||||
if self._before_epoch:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def after_epoch(self, runner: object) -> None:
|
||||
"""Empty cache after an epoch.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
"""
|
||||
if self._after_epoch:
|
||||
torch.cuda.empty_cache()
|
14
tests/test_hook/test_empty_cache_hook.py
Normal file
14
tests/test_hook/test_empty_cache_hook.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mock import Mock
|
||||
|
||||
from mmengine.hooks import EmptyCacheHook
|
||||
|
||||
|
||||
class TestEmptyCacheHook:
|
||||
|
||||
def test_emtpy_cache_hook(self):
|
||||
Hook = EmptyCacheHook(True, True, True)
|
||||
Runner = Mock()
|
||||
Hook.after_iter(Runner)
|
||||
Hook.before_epoch(Runner)
|
||||
Hook.after_epoch(Runner)
|
Loading…
x
Reference in New Issue
Block a user