[Feature] Add empty cache hook (#58)

* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* [Fix]: Add docstring and type hint for base hook

* [Fix]: Add test case to not the last iter, inner_iter, epoch

* [Fix]: Add missing type hint

* [Feature]: Add Args and Returns in docstring

* [Fix]: Add missing colon

* [Fix]: Add optional to docstring

* [Fix]: Fix docstring problem

* [Fix]: Fix lint

* fix lint

* update typing and docs

* fix lint

* Update mmengine/hooks/empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/hooks/empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/hooks/empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_hook/test_empty_cache_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

* fix comments

* remove test condition

Co-authored-by: seuyou <3463423099@qq.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
Yifei Yang 2022-03-02 14:04:41 +08:00 committed by GitHub
parent 63a3af4f8c
commit 94ab45d07e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .empty_cache_hook import EmptyCacheHook
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .optimizer_hook import OptimizerHook
@ -7,5 +8,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
__all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'OptimizerHook'
'OptimizerHook', 'EmptyCacheHook'
]

View File

@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import torch
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
@HOOKS.register_module()
class EmptyCacheHook(Hook):
"""Releases all unoccupied cached GPU memory during the process of
training.
Args:
before_epoch (bool): Whether to release cache before an epoch. Defaults
to False.
after_epoch (bool): Whether to release cache after an epoch. Defaults
to True.
after_iter (bool): Whether to release cache after an iteration.
Defaults to False.
"""
def __init__(self,
before_epoch: bool = False,
after_epoch: bool = True,
after_iter: bool = False) -> None:
self._before_epoch = before_epoch
self._after_epoch = after_epoch
self._after_iter = after_iter
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Empty cache after an iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
"""
if self._after_iter:
torch.cuda.empty_cache()
def before_epoch(self, runner: object) -> None:
"""Empty cache before an epoch.
Args:
runner (object): The runner of the training process.
"""
if self._before_epoch:
torch.cuda.empty_cache()
def after_epoch(self, runner: object) -> None:
"""Empty cache after an epoch.
Args:
runner (object): The runner of the training process.
"""
if self._after_epoch:
torch.cuda.empty_cache()

View File

@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mock import Mock
from mmengine.hooks import EmptyCacheHook
class TestEmptyCacheHook:
def test_emtpy_cache_hook(self):
Hook = EmptyCacheHook(True, True, True)
Runner = Mock()
Hook.after_iter(Runner)
Hook.before_epoch(Runner)
Hook.after_epoch(Runner)