mirror of https://github.com/open-mmlab/mmcv.git
add EmptyCacheHook and rename some modules
parent
f4550cd319
commit
ecfce392fa
|
@ -1,15 +1,16 @@
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
from .checkpoint_saver import CheckpointHook
|
from .checkpoint import CheckpointHook
|
||||||
from .closure import ClosureHook
|
from .closure import ClosureHook
|
||||||
from .lr_updater import LrUpdaterHook
|
from .lr_updater import LrUpdaterHook
|
||||||
from .optimizer_stepper import OptimizerHook
|
from .optimizer import OptimizerHook
|
||||||
from .iter_timer import IterTimerHook
|
from .iter_timer import IterTimerHook
|
||||||
from .sampler_seed import DistSamplerSeedHook
|
from .sampler_seed import DistSamplerSeedHook
|
||||||
|
from .memory import EmptyCacheHook
|
||||||
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
|
from .logger import (LoggerHook, TextLoggerHook, PaviLoggerHook,
|
||||||
TensorboardLoggerHook)
|
TensorboardLoggerHook)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
|
'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook',
|
||||||
'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'TextLoggerHook',
|
'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook',
|
||||||
'PaviLoggerHook', 'TensorboardLoggerHook'
|
'TextLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyCacheHook(Hook):
|
||||||
|
|
||||||
|
def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
|
||||||
|
self._before_epoch = before_epoch
|
||||||
|
self._after_epoch = after_epoch
|
||||||
|
self._after_iter = after_iter
|
||||||
|
|
||||||
|
def after_iter(self, runner):
|
||||||
|
if self._after_iter:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def before_epoch(self, runner):
|
||||||
|
if self._before_epoch:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def after_epoch(self, runner):
|
||||||
|
if self._after_epoch:
|
||||||
|
torch.cuda.empty_cache()
|
Loading…
Reference in New Issue