add EmptyCacheHook and rename some modules

pull/19/head
Kai Chen 2018-10-07 17:30:00 +08:00
parent f4550cd319
commit ecfce392fa
4 changed files with 28 additions and 4 deletions

View File

@ -1,15 +1,16 @@
from .hook import Hook
from .checkpoint_saver import CheckpointHook
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .lr_updater import LrUpdaterHook
from .optimizer_stepper import OptimizerHook
from .optimizer import OptimizerHook
from .iter_timer import IterTimerHook
from .sampler_seed import DistSamplerSeedHook
from .memory import EmptyCacheHook
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
TensorboardLoggerHook)
__all__ = [
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
'PaviLoggerHook', 'TensorboardLoggerHook'
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
]

View File

@ -0,0 +1,23 @@
import torch
from .hook import Hook
class EmptyCacheHook(Hook):
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
self._before_epoch = before_epoch
self._after_epoch = after_epoch
self._after_iter = after_iter
def after_iter(self, runner):
if self._after_iter:
torch.cuda.empty_cache()
def before_epoch(self, runner):
if self._before_epoch:
torch.cuda.empty_cache()
def after_epoch(self, runner):
if self._after_epoch:
torch.cuda.empty_cache()