[Feature] Profiling tools (#241)
* Add profiling tools * fix docstr * fix docstr * update * fix bug * update * update * fix error * fix mypy * update * merge main * fix UT
parent 8d3bd4dfef
commit c197bdf359
@@ -43,7 +43,17 @@ Logging
 .. automodule:: mmengine.logging
    :members:
 
+Model
+--------
+.. automodule:: mmengine.model
+   :members:
+
+Visualization
+--------
+.. automodule:: mmengine.visualization
+   :members:
+
 Utils
 --------
 .. automodule:: mmengine.utils
    :members:
@@ -229,7 +229,7 @@ MMEngine's registry supports cross-project calls, i.e. modules defined in one project can be used
 - TASK_UTILS: components strongly tied to a task, such as `AnchorGenerator` and `BboxCoder`
 - VISUALIZERS: manages drawing modules, e.g. `DetVisualizer` can draw predicted boxes on an image
 - WRITERS: backends that store training logs, such as `LocalWriter` and `TensorboardWriter`
-- LOG_PROCESSOR: controls the statistics window and statistics methods of logs; `LogProcessor` is used by default, and a custom `LogProcessor` can be defined for special needs
+- LOG_PROCESSORS: controls the statistics window and statistics methods of logs; `LogProcessor` is used by default, and a custom `LogProcessor` can be defined for special needs
 
 Below we take the OpenMMLab open-source projects as an example to introduce how to call modules across projects.
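For reference, a minimal sketch of how a component is registered and built through the renamed LOG_PROCESSORS registry; the `MyLogProcessor` name and the `window_size` value are illustrative, not part of this commit:

from mmengine.logging import LogProcessor
from mmengine.registry import LOG_PROCESSORS

# Register a custom processor under the renamed registry, then build it
# from a config dict, exactly as Runner does further down in this diff.
@LOG_PROCESSORS.register_module()
class MyLogProcessor(LogProcessor):

    def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
        super().__init__(window_size, by_epoch, custom_cfg)

log_processor = LOG_PROCESSORS.build(dict(type='MyLogProcessor', window_size=20))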
@@ -17,7 +17,8 @@ import mmengine
 from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
                     get_default_group, barrier, get_data_device,
                     get_comm_device, cast_data_device)
-from mmengine.utils import digit_version, TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
+from mmengine.utils.parrots_wrapper import TORCH_VERSION
 
 
 def _get_reduce_op(name: str) -> torch_dist.ReduceOp:
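A hedged aside on how these two symbols are typically used together; the version threshold '1.8.0' is an arbitrary example, not taken from this commit:

from mmengine.utils.version_utils import digit_version
from mmengine.utils.parrots_wrapper import TORCH_VERSION

# digit_version converts a version string into a comparable tuple,
# so feature gates become ordinary comparisons.
if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
    print('torch >= 1.8.0 detected')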
@@ -7,10 +7,10 @@ from typing import List, Optional, Tuple
 import torch
 
 from mmengine.device import get_max_cuda_memory
-from mmengine.registry import LOG_PROCESSOR
+from mmengine.registry import LOG_PROCESSORS
 
 
-@LOG_PROCESSOR.register_module()
+@LOG_PROCESSORS.register_module()  # type: ignore
 class LogProcessor:
     """A log processor used to format log information collected from
     ``runner.message_hub.log_scalars``.
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .default_scope import DefaultScope
 from .registry import Registry, build_from_cfg
-from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSOR, LOOPS,
+from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
                    METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
                    OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
                    TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
@@ -13,6 +13,6 @@ __all__ = [
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
     'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
     'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
-    'LOG_PROCESSOR', 'DefaultScope', 'traverse_registry_tree',
+    'LOG_PROCESSORS', 'DefaultScope', 'traverse_registry_tree',
     'count_registered_modules'
 ]
@@ -47,4 +47,4 @@ VISUALIZERS = Registry('visualizer')
 VISBACKENDS = Registry('vis_backend')
 
 # manage logprocessor
-LOG_PROCESSOR = Registry('log_processor')
+LOG_PROCESSORS = Registry('log_processor')
@@ -30,7 +30,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
                                VISUALIZERS, DefaultScope,
                                count_registered_modules)
-from mmengine.registry.root import LOG_PROCESSOR
+from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, digit_version,
                             find_latest_checkpoint, is_list_of,
                             set_multi_processing, symlink)
@@ -1176,7 +1176,7 @@ class Runner:
         log_processor_cfg = copy.deepcopy(log_processor)  # type: ignore
 
         if 'type' in log_processor_cfg:
-            log_processor = LOG_PROCESSOR.build(log_processor_cfg)
+            log_processor = LOG_PROCESSORS.build(log_processor_cfg)
         else:
             log_processor = LogProcessor(**log_processor_cfg)  # type: ignore
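The two branches correspond to two config shapes. A sketch, with illustrative key values:

# With an explicit 'type' key, the processor is built through the
# LOG_PROCESSORS registry, so any registered subclass can be selected:
log_processor = dict(type='LogProcessor', window_size=20)

# Without 'type', the remaining keys are passed straight to
# LogProcessor(**log_processor_cfg):
log_processor = dict(window_size=10, by_epoch=True)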
@@ -15,6 +15,9 @@ from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
 from .setup_env import set_multi_processing
 from .version_utils import digit_version, get_git_hash
 
+# TODO: creates intractable circular import issues
+# from .time_counter import TimeCounter
+
 __all__ = [
     'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
     'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
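Until that circular import is resolved, `TimeCounter` has to be imported from its defining module rather than from the package, which is what the new unit test below does:

# The package-level re-export is commented out above, so import the
# class from its submodule directly.
from mmengine.utils.time_counter import TimeCounter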
@@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Union

import torch

from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log


class TimeCounter:
    """A tool that counts the average running time of a function or a method.
    Users can use it as a decorator or context manager to calculate the
    average running time of code blocks.

    Args:
        log_interval (int): The interval of logging. Defaults to 1.
        warmup_interval (int): The interval of warmup. Defaults to 1.
        with_sync (bool): Whether to synchronize cuda. Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.time_counter import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms

        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms

        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    logger: Optional[MMLogger]
    __count: int
    __pure_inf_time: float

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        assert warmup_interval >= 1
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger

        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.

        return instance

    @master_only
    def __call__(self, fn):
        if self.tag is None:
            self.tag = fn.__name__

        def wrapper(*args, **kwargs):
            self.__count += 1

            if self.with_sync and torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            result = fn(*args, **kwargs)

            if self.with_sync and torch.cuda.is_available():
                torch.cuda.synchronize()

            elapsed = time.perf_counter() - start_time
            self.print_time(elapsed)

            return result

        return wrapper

    @master_only
    def __enter__(self):
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'

        self.__count += 1

        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        self.__start_time = time.perf_counter()

    @master_only
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - self.__start_time
        self.print_time(elapsed)

    def print_time(self, elapsed: Union[int, float]) -> None:
        """Print the average time per run."""
        if self.__count >= self.warmup_interval:
            self.__pure_inf_time += elapsed

            if self.__count % self.log_interval == 0:
                times_per_count = 1000 * self.__pure_inf_time / (
                    self.__count - self.warmup_interval + 1)
                print_log(
                    f'[{self.tag}]-time per run averaged in the past '
                    f'{self.__count} runs: {times_per_count:.1f} ms',
                    self.logger)
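One design choice worth noting: `__new__` caches instances by `tag`, so separate code blocks that reuse a tag share one counter and the reported average accumulates across them. A minimal sketch (the tag name is illustrative):

import time
from mmengine.utils.time_counter import TimeCounter

# Both blocks resolve to the same cached instance, so the second one
# reports an average over two accumulated runs.
with TimeCounter(tag='stage'):
    time.sleep(0.1)
with TimeCounter(tag='stage'):
    time.sleep(0.1)
# prints: [stage]-time per run averaged in the past 2 runs: ~100.0 ms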
@@ -19,7 +19,7 @@ from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook,
 from mmengine.hooks.checkpoint_hook import CheckpointHook
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.optim.scheduler import MultiStepLR, StepLR
-from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSOR, LOOPS, METRICS,
+from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
                                MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
                                Registry)
 from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
@@ -172,7 +172,7 @@ class CustomTestLoop(BaseLoop):
     pass
 
 
-@LOG_PROCESSOR.register_module()
+@LOG_PROCESSORS.register_module()
 class CustomLogProcessor(LogProcessor):
 
     def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
import unittest

from mmengine.utils.time_counter import TimeCounter


class TestTimeCounter(unittest.TestCase):

    def test_decorate_timer(self):

        @TimeCounter()
        def demo_fun():
            time.sleep(0.1)

        demo_fun()

        @TimeCounter()
        def demo_fun():
            time.sleep(0.1)

        for _ in range(10):
            demo_fun()

        @TimeCounter(log_interval=2, with_sync=False, tag='demo_fun1')
        def demo_fun():
            time.sleep(0.1)

        demo_fun()

        # warmup_interval must be greater than 0
        with self.assertRaises(AssertionError):

            @TimeCounter(warmup_interval=0)
            def demo_fun():
                time.sleep(0.1)

    def test_context_timer(self):

        # tag must be specified in context mode
        with self.assertRaises(AssertionError):
            with TimeCounter():
                time.sleep(0.1)

        # warmup_interval must be greater than 0
        with self.assertRaises(AssertionError):
            with TimeCounter(warmup_interval=0, tag='func_1'):
                time.sleep(0.1)

        with TimeCounter(tag='func_1'):
            time.sleep(0.1)

        for _ in range(10):
            with TimeCounter(log_interval=2, with_sync=False, tag='func_2'):
                time.sleep(0.1)