mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add Iter Timer Hook (#48)
* [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * add iter timer hook * update test * [Fix]: Add docstring and type hint for base hook * fix mypy * improve doc coverage and merge main Co-authored-by: seuyou <3463423099@qq.com>
This commit is contained in:
parent
42448425b3
commit
1244e486ae
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
|
from .iter_timer_hook import IterTimerHook
|
||||||
|
|
||||||
__all__ = ['Hook']
|
__all__ = ['Hook', 'IterTimerHook']
|
||||||
|
58
mmengine/hooks/iter_timer_hook.py
Normal file
58
mmengine/hooks/iter_timer_hook.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import time
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class IterTimerHook(Hook):
|
||||||
|
"""A hook that logs the time spent during iteration.
|
||||||
|
|
||||||
|
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def before_epoch(self, runner: object) -> None:
|
||||||
|
"""Record time flag before start a epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
self.t = time.time()
|
||||||
|
|
||||||
|
def before_iter(
|
||||||
|
self,
|
||||||
|
runner: object,
|
||||||
|
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||||
|
"""Logging time for loading data and update the time flag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
# TODO: update for new logging system
|
||||||
|
runner.log_buffer.update({ # type: ignore
|
||||||
|
'data_time': time.time() - self.t
|
||||||
|
})
|
||||||
|
|
||||||
|
def after_iter(self,
|
||||||
|
runner: object,
|
||||||
|
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||||
|
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||||
|
"""Logging time for a iteration and update the time flag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||||
|
Defaults to None.
|
||||||
|
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
# TODO: update for new logging system
|
||||||
|
runner.log_buffer.update({ # type: ignore
|
||||||
|
'time': time.time() - self.t
|
||||||
|
})
|
||||||
|
self.t = time.time()
|
29
tests/test_hook/test_iter_timer_hook.py
Normal file
29
tests/test_hook/test_iter_timer_hook.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from mmengine.hooks import IterTimerHook
|
||||||
|
|
||||||
|
|
||||||
|
class TestIterTimerHook:
|
||||||
|
|
||||||
|
def test_before_epoch(self):
|
||||||
|
Hook = IterTimerHook()
|
||||||
|
Runner = Mock()
|
||||||
|
Hook.before_epoch(Runner)
|
||||||
|
assert isinstance(Hook.t, float)
|
||||||
|
|
||||||
|
def test_before_iter(self):
|
||||||
|
Hook = IterTimerHook()
|
||||||
|
Runner = Mock()
|
||||||
|
Runner.log_buffer = dict()
|
||||||
|
Hook.before_epoch(Runner)
|
||||||
|
Hook.before_iter(Runner)
|
||||||
|
assert 'data_time' in Runner.log_buffer
|
||||||
|
|
||||||
|
def test_after_iter(self):
|
||||||
|
Hook = IterTimerHook()
|
||||||
|
Runner = Mock()
|
||||||
|
Runner.log_buffer = dict()
|
||||||
|
Hook.before_epoch(Runner)
|
||||||
|
Hook.after_iter(Runner)
|
||||||
|
assert 'time' in Runner.log_buffer
|
Loading…
x
Reference in New Issue
Block a user