mmengine/mmengine/hooks/empty_cache_hook.py
Mashiro 8770c6c7fc
[Refactor] Refactor data flow to make the interface more natural (#468)
2022-08-24 22:04:55 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

import torch

from mmengine.registry import HOOKS
from .hook import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class EmptyCacheHook(Hook):
    """Releases all unoccupied cached GPU memory during the process of
    training.

    Args:
        before_epoch (bool): Whether to release cache before an epoch. Defaults
            to False.
        after_epoch (bool): Whether to release cache after an epoch. Defaults
            to True.
        after_iter (bool): Whether to release cache after an iteration.
            Defaults to False.
    """

    priority = 'NORMAL'

    def __init__(self,
                 before_epoch: bool = False,
                 after_epoch: bool = True,
                 after_iter: bool = False) -> None:
        self._do_before_epoch = before_epoch
        self._do_after_epoch = after_epoch
        self._do_after_iter = after_iter

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[dict, Sequence]] = None,
                    mode: str = 'train') -> None:
        """Empty cache after an iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict or sequence, optional): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if self._do_after_iter:
            torch.cuda.empty_cache()

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """Empty cache before an epoch.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if self._do_before_epoch:
            torch.cuda.empty_cache()

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """Empty cache after an epoch.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if self._do_after_epoch:
            torch.cuda.empty_cache()
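The hook's behaviour comes down to three boolean flags. Each flag gates one call to `torch.cuda.empty_cache()` at a fixed point in the loop.

The following is a minimal, dependency-free sketch of that gating only, so it can run without PyTorch or a GPU. Several parts are assumptions:

- `EmptyCacheSketch`, `run_loop` and the `release` callback are hypothetical names.
- `release` stands in for `torch.cuda.empty_cache`.
- The loop is a simplified stand-in, not the MMEngine `Runner` or `Hook` API.

```python
from typing import Callable, List


class EmptyCacheSketch:
    """Hypothetical stand-in for EmptyCacheHook's flag-gated dispatch.

    ``release`` is an assumed replacement for ``torch.cuda.empty_cache`` so
    the gating logic can be exercised without a GPU or PyTorch installed.
    """

    def __init__(self,
                 release: Callable[[], None],
                 before_epoch: bool = False,
                 after_epoch: bool = True,
                 after_iter: bool = False) -> None:
        # Same defaults as EmptyCacheHook: only release after each epoch.
        self._release = release
        self._do_before_epoch = before_epoch
        self._do_after_epoch = after_epoch
        self._do_after_iter = after_iter

    def before_epoch(self) -> None:
        if self._do_before_epoch:
            self._release()

    def after_iter(self) -> None:
        if self._do_after_iter:
            self._release()

    def after_epoch(self) -> None:
        if self._do_after_epoch:
            self._release()


def run_loop(hook: EmptyCacheSketch, epochs: int, iters: int) -> None:
    """Call the hook points in the order a simple epoch/iteration loop would."""
    for _ in range(epochs):
        hook.before_epoch()
        for _ in range(iters):
            hook.after_iter()
        hook.after_epoch()


calls: List[str] = []

# Default flags: one release at the end of each of the 2 epochs.
hook = EmptyCacheSketch(release=lambda: calls.append('empty_cache'))
run_loop(hook, epochs=2, iters=3)
print(len(calls))  # 2

# after_iter=True adds one release per iteration: 2 * 3 + 2 epoch ends.
calls.clear()
hook = EmptyCacheSketch(release=lambda: calls.append('empty_cache'),
                        after_iter=True)
run_loop(hook, epochs=2, iters=3)
print(len(calls))  # 8
```

In an actual MMEngine setup, the hook would normally be built from the registry rather than called directly. For example, pass `custom_hooks=[dict(type='EmptyCacheHook', after_iter=True)]` to the `Runner`.

As far as I can tell, the base `Hook` forwards its train, val and test callbacks (for example `after_train_iter` and `after_val_epoch`) to `_after_iter` and `_after_epoch` with the matching `mode`. If so, the flags are not limited to training.

Enabling `after_iter` likely costs some speed. `empty_cache` only returns blocks that no tensor occupies, and the caching allocator then has to request that memory from the driver again on the next iteration.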