mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add Iter Timer Hook (#48)
* [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * add iter timer hook * update test * [Fix]: Add docstring and type hint for base hook * fix mypy * improve doc coverage and merge main Co-authored-by: seuyou <3463423099@qq.com>
This commit is contained in:
parent
42448425b3
commit
1244e486ae
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hook import Hook
|
||||
from .iter_timer_hook import IterTimerHook
|
||||
|
||||
__all__ = ['Hook']
|
||||
__all__ = ['Hook', 'IterTimerHook']
|
||||
|
58
mmengine/hooks/iter_timer_hook.py
Normal file
58
mmengine/hooks/iter_timer_hook.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import time
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class IterTimerHook(Hook):
|
||||
"""A hook that logs the time spent during iteration.
|
||||
|
||||
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
||||
"""
|
||||
|
||||
def before_epoch(self, runner: object) -> None:
|
||||
"""Record time flag before start a epoch.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
"""
|
||||
self.t = time.time()
|
||||
|
||||
def before_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Logging time for loading data and update the time flag.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||
Defaults to None.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
runner.log_buffer.update({ # type: ignore
|
||||
'data_time': time.time() - self.t
|
||||
})
|
||||
|
||||
def after_iter(self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Logging time for a iteration and update the time flag.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
runner.log_buffer.update({ # type: ignore
|
||||
'time': time.time() - self.t
|
||||
})
|
||||
self.t = time.time()
|
29
tests/test_hook/test_iter_timer_hook.py
Normal file
29
tests/test_hook/test_iter_timer_hook.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
|
||||
from mmengine.hooks import IterTimerHook
|
||||
|
||||
|
||||
class TestIterTimerHook:
|
||||
|
||||
def test_before_epoch(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Hook.before_epoch(Runner)
|
||||
assert isinstance(Hook.t, float)
|
||||
|
||||
def test_before_iter(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Runner.log_buffer = dict()
|
||||
Hook.before_epoch(Runner)
|
||||
Hook.before_iter(Runner)
|
||||
assert 'data_time' in Runner.log_buffer
|
||||
|
||||
def test_after_iter(self):
|
||||
Hook = IterTimerHook()
|
||||
Runner = Mock()
|
||||
Runner.log_buffer = dict()
|
||||
Hook.before_epoch(Runner)
|
||||
Hook.after_iter(Runner)
|
||||
assert 'time' in Runner.log_buffer
|
Loading…
x
Reference in New Issue
Block a user