add EmptyCacheHook and rename some modules

pull/19/head
Kai Chen 2018-10-07 17:30:00 +08:00
parent f4550cd319
commit ecfce392fa
4 changed files with 28 additions and 4 deletions

View File

@ -1,15 +1,16 @@
from .hook import Hook from .hook import Hook
from .checkpoint_saver import CheckpointHook from .checkpoint import CheckpointHook
from .closure import ClosureHook from .closure import ClosureHook
from .lr_updater import LrUpdaterHook from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerHook from .optimizer import OptimizerHook
from .iter_timer import IterTimerHook from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook from .sampler_seed import DistSamplerSeedHook
from .memory import EmptyCacheHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook, from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook) TensorboardLoggerHook)
__all__ = [ __all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
'PaviLoggerHook', 'TensorboardLoggerHook' 'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
] ]

View File

@ -0,0 +1,23 @@
import torch
from .hook import Hook
class EmptyCacheHook(Hook):
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
self._before_epoch = before_epoch
self._after_epoch = after_epoch
self._after_iter = after_iter
def after_iter(self, runner):
if self._after_iter:
torch.cuda.empty_cache()
def before_epoch(self, runner):
if self._before_epoch:
torch.cuda.empty_cache()
def after_epoch(self, runner):
if self._after_epoch:
torch.cuda.empty_cache()