[Feature] Profiling tools (#241)

* Add profiling tools

* fix docstr

* fix docstr

* update

* fix bug

* update

* update

* fix error

* fix mypy

* update

* merge main

* fix UT
pull/261/head
Haian Huang(深度眸) 2022-05-25 10:55:07 +08:00 committed by GitHub
parent 8d3bd4dfef
commit c197bdf359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 11 deletions

View File

@ -43,7 +43,17 @@ Logging
.. automodule:: mmengine.logging
:members:
Model
--------
.. automodule:: mmengine.model
:members:
Visualization
-------------
.. automodule:: mmengine.visualization
:members:
Utils
--------
.. automodule:: mmengine.utils
:members:

View File

@ -229,7 +229,7 @@ MMEngine 的注册器支持跨项目调用,即可以在一个项目中使用
- TASK_UTILS: 任务强相关的一些组件,如 `AnchorGenerator`, `BboxCoder`
- VISUALIZERS: 管理绘制模块,如 `DetVisualizer` 可在图片上绘制预测框
- WRITERS: 存储训练日志的后端,如 `LocalWriter`, `TensorboardWriter`
- LOG_PROCESSOR: 控制日志的统计窗口和统计方法,默认使用 `LogProcessor`,如有特殊需求可自定义 `LogProcessor`
- LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 `LogProcessor`,如有特殊需求可自定义 `LogProcessor`
下面我们以 OpenMMLab 开源项目为例介绍如何跨项目调用模块。

View File

@ -17,7 +17,8 @@ import mmengine
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group, barrier, get_data_device,
get_comm_device, cast_data_device)
from mmengine.utils import digit_version, TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from mmengine.utils.parrots_wrapper import TORCH_VERSION
def _get_reduce_op(name: str) -> torch_dist.ReduceOp:

View File

@ -7,10 +7,10 @@ from typing import List, Optional, Tuple
import torch
from mmengine.device import get_max_cuda_memory
from mmengine.registry import LOG_PROCESSOR
from mmengine.registry import LOG_PROCESSORS
@LOG_PROCESSOR.register_module()
@LOG_PROCESSORS.register_module() # type: ignore
class LogProcessor:
"""A log processor used to format log information collected from
``runner.message_hub.log_scalars``.

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_scope import DefaultScope
from .registry import Registry, build_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSOR, LOOPS,
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
@ -13,6 +13,6 @@ __all__ = [
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
'LOG_PROCESSOR', 'DefaultScope', 'traverse_registry_tree',
'LOG_PROCESSORS', 'DefaultScope', 'traverse_registry_tree',
'count_registered_modules'
]

View File

@ -47,4 +47,4 @@ VISUALIZERS = Registry('visualizer')
VISBACKENDS = Registry('vis_backend')
# manage logprocessor
LOG_PROCESSOR = Registry('log_processor')
LOG_PROCESSORS = Registry('log_processor')

View File

@ -30,7 +30,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
VISUALIZERS, DefaultScope,
count_registered_modules)
from mmengine.registry.root import LOG_PROCESSOR
from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of,
set_multi_processing, symlink)
@ -1176,7 +1176,7 @@ class Runner:
log_processor_cfg = copy.deepcopy(log_processor) # type: ignore
if 'type' in log_processor_cfg:
log_processor = LOG_PROCESSOR.build(log_processor_cfg)
log_processor = LOG_PROCESSORS.build(log_processor_cfg)
else:
log_processor = LogProcessor(**log_processor_cfg) # type: ignore

View File

@ -15,6 +15,9 @@ from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
from .setup_env import set_multi_processing
from .version_utils import digit_version, get_git_hash
# TODO: importing TimeCounter here creates intractable circular import issues
# from .time_counter import TimeCounter
__all__ = [
'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',

View File

@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Union
import torch
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log
class TimeCounter:
    """A tool that counts the average running time of a function or a method.

    Users can use it as a decorator or context manager to calculate the average
    running time of code blocks.

    Instances are cached by ``tag``: building a ``TimeCounter`` with a tag that
    was used before returns the already existing instance and ignores the
    other arguments. This lets a ``with TimeCounter(tag=...)`` statement that
    sits inside a loop keep accumulating into one counter.

    Args:
        log_interval (int): The interval of logging. A message is emitted
            whenever the number of runs is a multiple of this value.
            Defaults to 1.
        warmup_interval (int): The interval of warmup. Runs before the
            ``warmup_interval``-th one are excluded from the average. Must
            be at least 1. Defaults to 1.
        with_sync (bool): Whether to synchronize cuda. Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. When used as a
            decorator and left as None, the name of the decorated function
            is used. Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.time_counter import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms

        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms

        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    # Cache of created instances keyed by tag. Untagged instances are stored
    # under the ``None`` key, so only the most recent one is kept there.
    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    with_sync: bool
    tag: Optional[str]
    logger: Optional[MMLogger]
    __count: int
    __pure_inf_time: float

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        """Create a counter, or return the cached one registered for ``tag``.

        Raises:
            AssertionError: If ``warmup_interval`` is smaller than 1.
        """
        assert warmup_interval >= 1
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger

        # Number of runs seen so far, accumulated time (seconds) of the runs
        # after warmup, and the start time of the active ``with`` block.
        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.

        return instance

    def __call__(self, fn):
        """Decorate ``fn`` so that every call is timed.

        Args:
            fn (Callable): The function or method to time.

        Returns:
            Callable: The timing wrapper, or ``fn`` itself when no wrapper was
            built (see the note in the body).
        """
        wrapper = self._build_wrapper(fn)
        # NOTE(review): ``master_only`` presumably skips the call and yields
        # ``None`` on non-master ranks - confirm against mmengine.dist.utils.
        # Returning ``fn`` in that case keeps the decorated function callable
        # there instead of replacing it with ``None``.
        return fn if wrapper is None else wrapper

    @master_only
    def _build_wrapper(self, fn):
        """Build the wrapper that times each call of ``fn``."""
        import functools

        if self.tag is None:
            self.tag = fn.__name__

        # ``wraps`` keeps ``fn.__name__`` / ``fn.__doc__`` on the wrapper.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            self.__count += 1

            # Wait for pending cuda work so it is not attributed to ``fn``.
            if self.with_sync and torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            result = fn(*args, **kwargs)

            # Wait for the cuda work launched by ``fn`` before stopping.
            if self.with_sync and torch.cuda.is_available():
                torch.cuda.synchronize()

            elapsed = time.perf_counter() - start_time
            self.print_time(elapsed)

            return result

        return wrapper

    @master_only
    def __enter__(self):
        """Start timing the body of a ``with`` statement.

        Raises:
            AssertionError: If no ``tag`` was specified.
        """
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'

        self.__count += 1

        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        self.__start_time = time.perf_counter()

    @master_only
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing the ``with`` body and record the elapsed time."""
        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - self.__start_time
        self.print_time(elapsed)

    def print_time(self, elapsed: Union[int, float]) -> None:
        """Accumulate ``elapsed`` and log the average time per run.

        Args:
            elapsed (int or float): Duration of the latest run in seconds.
        """
        # Warmup runs are neither accumulated nor logged.
        if self.__count < self.warmup_interval:
            return

        self.__pure_inf_time += elapsed

        if self.__count % self.log_interval == 0:
            # Only the runs from the ``warmup_interval``-th one onwards were
            # accumulated, so report that number rather than the total count.
            measured_runs = self.__count - self.warmup_interval + 1
            times_per_count = 1000 * self.__pure_inf_time / measured_runs
            print_log(
                f'[{self.tag}]-time per run averaged in the past '
                f'{measured_runs} runs: {times_per_count:.1f} ms',
                self.logger)

View File

@ -19,7 +19,7 @@ from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook,
from mmengine.hooks.checkpoint_hook import CheckpointHook
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.optim.scheduler import MultiStepLR, StepLR
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSOR, LOOPS, METRICS,
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
@ -172,7 +172,7 @@ class CustomTestLoop(BaseLoop):
pass
@LOG_PROCESSOR.register_module()
@LOG_PROCESSORS.register_module()
class CustomLogProcessor(LogProcessor):
def __init__(self, window_size=10, by_epoch=True, custom_cfg=None):

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
import unittest
from mmengine.utils.time_counter import TimeCounter
class TestTimeCounter(unittest.TestCase):
    """Exercise ``TimeCounter`` as a decorator and as a context manager."""

    def test_decorate_timer(self):
        """Decorated functions can be called with various settings."""

        # Default arguments, called once.
        @TimeCounter()
        def demo_fun():
            time.sleep(0.1)

        demo_fun()

        # Default arguments, called repeatedly.
        @TimeCounter()
        def demo_fun():
            time.sleep(0.1)

        for _ in range(10):
            demo_fun()

        # Explicit tag, no cuda synchronization, logging every second run.
        @TimeCounter(log_interval=2, with_sync=False, tag='demo_fun1')
        def demo_fun():
            time.sleep(0.1)

        demo_fun()

        # A warmup_interval below 1 is rejected.
        with self.assertRaises(AssertionError):

            @TimeCounter(warmup_interval=0)
            def demo_fun():
                time.sleep(0.1)

    def test_context_timer(self):
        """``with`` blocks require a tag and a valid warmup_interval."""

        # Context mode without a tag is rejected.
        with self.assertRaises(AssertionError), TimeCounter():
            time.sleep(0.1)

        # A warmup_interval below 1 is rejected.
        with self.assertRaises(AssertionError), \
                TimeCounter(warmup_interval=0, tag='func_1'):
            time.sleep(0.1)

        # A tagged context runs normally.
        with TimeCounter(tag='func_1'):
            time.sleep(0.1)

        # Re-entering the same tag inside a loop.
        for _ in range(10):
            with TimeCounter(log_interval=2, with_sync=False, tag='func_2'):
                time.sleep(0.1)