diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py index 67fff16fc..3812e4a05 100644 --- a/mmcv/runner/hooks/__init__.py +++ b/mmcv/runner/hooks/__init__.py @@ -1,15 +1,16 @@ from .hook import Hook -from .checkpoint_saver import CheckpointHook +from .checkpoint import CheckpointHook from .closure import ClosureHook from .lr_updater import LrUpdaterHook -from .optimizer_stepper import OptimizerHook +from .optimizer import OptimizerHook from .iter_timer import IterTimerHook from .sampler_seed import DistSamplerSeedHook +from .memory import EmptyCacheHook from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook, TensorboardLoggerHook) __all__ = [ 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', - 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook', - 'PaviLoggerHook', 'TensorboardLoggerHook' + 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', + 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook' ] diff --git a/mmcv/runner/hooks/checkpoint_saver.py b/mmcv/runner/hooks/checkpoint.py similarity index 100% rename from mmcv/runner/hooks/checkpoint_saver.py rename to mmcv/runner/hooks/checkpoint.py diff --git a/mmcv/runner/hooks/memory.py b/mmcv/runner/hooks/memory.py new file mode 100644 index 000000000..6bd11d0d6 --- /dev/null +++ b/mmcv/runner/hooks/memory.py @@ -0,0 +1,23 @@ +import torch + +from .hook import Hook + + +class EmptyCacheHook(Hook): + + def __init__(self, before_epoch=False, after_epoch=True, after_iter=False): + self._before_epoch = before_epoch + self._after_epoch = after_epoch + self._after_iter = after_iter + + def after_iter(self, runner): + if self._after_iter: + torch.cuda.empty_cache() + + def before_epoch(self, runner): + if self._before_epoch: + torch.cuda.empty_cache() + + def after_epoch(self, runner): + if self._after_epoch: + torch.cuda.empty_cache() diff --git a/mmcv/runner/hooks/optimizer_stepper.py b/mmcv/runner/hooks/optimizer.py similarity index 100% rename from mmcv/runner/hooks/optimizer_stepper.py rename to mmcv/runner/hooks/optimizer.py