mirror of https://github.com/open-mmlab/mmcv.git
add EmptyCacheHook and rename some modules
parent
f4550cd319
commit
ecfce392fa
|
@ -1,15 +1,16 @@
|
|||
from .hook import Hook
|
||||
from .checkpoint_saver import CheckpointHook
|
||||
from .checkpoint import CheckpointHook
|
||||
from .closure import ClosureHook
|
||||
from .lr_updater import LrUpdaterHook
|
||||
from .optimizer_stepper import OptimizerHook
|
||||
from .optimizer import OptimizerHook
|
||||
from .iter_timer import IterTimerHook
|
||||
from .sampler_seed import DistSamplerSeedHook
|
||||
from .memory import EmptyCacheHook
|
||||
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
|
||||
TensorboardLoggerHook)
|
||||
|
||||
__all__ = [
|
||||
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
|
||||
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
|
||||
'PaviLoggerHook', 'TensorboardLoggerHook'
|
||||
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
|
||||
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
import torch
|
||||
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
class EmptyCacheHook(Hook):
|
||||
|
||||
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
|
||||
self._before_epoch = before_epoch
|
||||
self._after_epoch = after_epoch
|
||||
self._after_iter = after_iter
|
||||
|
||||
def after_iter(self, runner):
|
||||
if self._after_iter:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def before_epoch(self, runner):
|
||||
if self._before_epoch:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def after_epoch(self, runner):
|
||||
if self._after_epoch:
|
||||
torch.cuda.empty_cache()
|
Loading…
Reference in New Issue