From c197bdf3591e617b106e1a0afb68ab033176c611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Wed, 25 May 2022 10:55:07 +0800 Subject: [PATCH] [Feature] Profiling tools (#241) * Add profiling tools * fix docstr * fix docstr * update * fix bug * update * update * fix error * fix mypy * uodate * merge main * fix UT --- docs/en/api.rst | 10 ++ docs/zh_cn/tutorials/registry.md | 2 +- mmengine/dist/dist.py | 3 +- mmengine/logging/log_processor.py | 4 +- mmengine/registry/__init__.py | 4 +- mmengine/registry/root.py | 2 +- mmengine/runner/runner.py | 4 +- mmengine/utils/__init__.py | 3 + mmengine/utils/time_counter.py | 134 ++++++++++++++++++++++++++ tests/test_runner/test_runner.py | 4 +- tests/test_utils/test_time_counter.py | 55 +++++++++++ 11 files changed, 214 insertions(+), 11 deletions(-) create mode 100644 mmengine/utils/time_counter.py create mode 100644 tests/test_utils/test_time_counter.py diff --git a/docs/en/api.rst b/docs/en/api.rst index b12bf93c..93ff6db8 100644 --- a/docs/en/api.rst +++ b/docs/en/api.rst @@ -43,7 +43,17 @@ Logging .. automodule:: mmengine.logging :members: +Model +-------- +.. automodule:: mmengine.model + :members: + Visualization -------- .. automodule:: mmengine.visualization :members: + +Utils +-------- +.. automodule:: mmengine.utils + :members: diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md index 201198da..5ed7b143 100644 --- a/docs/zh_cn/tutorials/registry.md +++ b/docs/zh_cn/tutorials/registry.md @@ -229,7 +229,7 @@ MMEngine 的注册器支持跨项目调用,即可以在一个项目中使用 - TASK_UTILS: 任务强相关的一些组件,如 `AnchorGenerator`, `BboxCoder` - VISUALIZERS: 管理绘制模块,如 `DetVisualizer` 可在图片上绘制预测框 - WRITERS: 存储训练日志的后端,如 `LocalWriter`, `TensorboardWriter` -- LOG_PROCESSOR: 控制日志的统计窗口和统计方法,默认使用 `LogProcessor`,如有特殊需求可自定义 `LogProcessor` +- LOG_PROCESSORS: 控制日志的统计窗口和统计方法,默认使用 `LogProcessor`,如有特殊需求可自定义 `LogProcessor` 下面我们以 OpenMMLab 开源项目为例介绍如何跨项目调用模块。 diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py index 690338cb..f4b586ed 100644 --- a/mmengine/dist/dist.py +++ b/mmengine/dist/dist.py @@ -17,7 +17,8 @@ import mmengine from .utils import (get_world_size, get_rank, get_backend, get_dist_info, get_default_group, barrier, get_data_device, get_comm_device, cast_data_device) -from mmengine.utils import digit_version, TORCH_VERSION +from mmengine.utils.version_utils import digit_version +from mmengine.utils.parrots_wrapper import TORCH_VERSION def _get_reduce_op(name: str) -> torch_dist.ReduceOp: diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py index 10710465..7d44a6b2 100644 --- a/mmengine/logging/log_processor.py +++ b/mmengine/logging/log_processor.py @@ -7,10 +7,10 @@ from typing import List, Optional, Tuple import torch from mmengine.device import get_max_cuda_memory -from mmengine.registry import LOG_PROCESSOR +from mmengine.registry import LOG_PROCESSORS -@LOG_PROCESSOR.register_module() +@LOG_PROCESSORS.register_module() # type: ignore class LogProcessor: """A log processor used to format log information collected from ``runner.message_hub.log_scalars``. diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 499ad94b..a67532b2 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .default_scope import DefaultScope from .registry import Registry, build_from_cfg -from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSOR, LOOPS, +from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, @@ -13,6 +13,6 @@ __all__ = [ 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', - 'LOG_PROCESSOR', 'DefaultScope', 'traverse_registry_tree', + 'LOG_PROCESSORS', 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules' ] diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index 3a64ae13..d3a47e47 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -47,4 +47,4 @@ VISUALIZERS = Registry('visualizer') VISBACKENDS = Registry('vis_backend') # manage logprocessor -LOG_PROCESSOR = Registry('log_processor') +LOG_PROCESSORS = Registry('log_processor') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 3f025964..b6ddd0bf 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -30,7 +30,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, VISUALIZERS, DefaultScope, count_registered_modules) -from mmengine.registry.root import LOG_PROCESSOR +from mmengine.registry.root import LOG_PROCESSORS from mmengine.utils import (TORCH_VERSION, digit_version, find_latest_checkpoint, is_list_of, set_multi_processing, symlink) @@ -1176,7 +1176,7 @@ class Runner: log_processor_cfg = copy.deepcopy(log_processor) # type: ignore if 'type' in log_processor_cfg: - log_processor = LOG_PROCESSOR.build(log_processor_cfg) + log_processor = LOG_PROCESSORS.build(log_processor_cfg) else: log_processor = LogProcessor(**log_processor_cfg) # type: ignore diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 5a7ca682..01e30f86 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -15,6 +15,9 @@ from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, from .setup_env import set_multi_processing from .version_utils import digit_version, get_git_hash +# TODO: creates intractable circular import issues +# from .time_counter import TimeCounter + __all__ = [ 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list', diff --git a/mmengine/utils/time_counter.py b/mmengine/utils/time_counter.py new file mode 100644 index 00000000..056956ab --- /dev/null +++ b/mmengine/utils/time_counter.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from typing import Optional, Union + +import torch + +from mmengine.dist.utils import master_only +from mmengine.logging import MMLogger, print_log + + +class TimeCounter: + """A tool that counts the average running time of a function or a method. + Users can use it as a decorator or context manager to calculate the average + running time of code blocks. + + Args: + log_interval (int): The interval of logging. Defaults to 1. + warmup_interval (int): The interval of warmup. Defaults to 1. + with_sync (bool): Whether to synchronize cuda. Defaults to True. + tag (str, optional): Function tag. Used to distinguish between + different functions or methods being called. Defaults to None. + logger (MMLogger, optional): Formatted logger used to record messages. + Defaults to None. + + Examples: + >>> import time + >>> from mmengine.utils import TimeCounter + >>> @TimeCounter() + ... def fun1(): + ... time.sleep(0.1) + ... fun1() + [fun1]-time per run averaged in the past 1 runs: 100.0 ms + + >>> @@TimeCounter(log_interval=2, tag='fun') + ... def fun2(): + ... time.sleep(0.2) + >>> for _ in range(3): + ... fun2() + [fun]-time per run averaged in the past 2 runs: 200.0 ms + + >>> with TimeCounter(tag='fun3'): + ... time.sleep(0.3) + [fun3]-time per run averaged in the past 1 runs: 300.0 ms + """ + + instance_dict: dict = dict() + + log_interval: int + warmup_interval: int + logger: Optional[MMLogger] + __count: int + __pure_inf_time: float + + def __new__(cls, + log_interval: int = 1, + warmup_interval: int = 1, + with_sync: bool = True, + tag: Optional[str] = None, + logger: Optional[MMLogger] = None): + assert warmup_interval >= 1 + if tag is not None and tag in cls.instance_dict: + return cls.instance_dict[tag] + + instance = super().__new__(cls) + cls.instance_dict[tag] = instance + + instance.log_interval = log_interval + instance.warmup_interval = warmup_interval + instance.with_sync = with_sync + instance.tag = tag + instance.logger = logger + + instance.__count = 0 + instance.__pure_inf_time = 0. + instance.__start_time = 0. + + return instance + + @master_only + def __call__(self, fn): + if self.tag is None: + self.tag = fn.__name__ + + def wrapper(*args, **kwargs): + self.__count += 1 + + if self.with_sync and torch.cuda.is_available(): + torch.cuda.synchronize() + start_time = time.perf_counter() + + result = fn(*args, **kwargs) + + if self.with_sync and torch.cuda.is_available(): + torch.cuda.synchronize() + + elapsed = time.perf_counter() - start_time + self.print_time(elapsed) + + return result + + return wrapper + + @master_only + def __enter__(self): + assert self.tag is not None, 'In order to clearly distinguish ' \ + 'printing information in different ' \ + 'contexts, please specify the ' \ + 'tag parameter' + + self.__count += 1 + + if self.with_sync and torch.cuda.is_available(): + torch.cuda.synchronize() + self.__start_time = time.perf_counter() + + @master_only + def __exit__(self, exc_type, exc_val, exc_tb): + if self.with_sync and torch.cuda.is_available(): + torch.cuda.synchronize() + elapsed = time.perf_counter() - self.__start_time + self.print_time(elapsed) + + def print_time(self, elapsed: Union[int, float]) -> None: + """print times per count.""" + if self.__count >= self.warmup_interval: + self.__pure_inf_time += elapsed + + if self.__count % self.log_interval == 0: + times_per_count = 1000 * self.__pure_inf_time / ( + self.__count - self.warmup_interval + 1) + print_log( + f'[{self.tag}]-time per run averaged in the past ' + f'{self.__count} runs: {times_per_count:.1f} ms', + self.logger) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 6f96d13f..36ce728c 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -19,7 +19,7 @@ from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook, from mmengine.hooks.checkpoint_hook import CheckpointHook from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.optim.scheduler import MultiStepLR, StepLR -from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSOR, LOOPS, METRICS, +from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, Registry) from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, @@ -172,7 +172,7 @@ class CustomTestLoop(BaseLoop): pass -@LOG_PROCESSOR.register_module() +@LOG_PROCESSORS.register_module() class CustomLogProcessor(LogProcessor): def __init__(self, window_size=10, by_epoch=True, custom_cfg=None): diff --git a/tests/test_utils/test_time_counter.py b/tests/test_utils/test_time_counter.py new file mode 100644 index 00000000..786c2ba3 --- /dev/null +++ b/tests/test_utils/test_time_counter.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +import unittest + +from mmengine.utils.time_counter import TimeCounter + + +class TestTimeCounter(unittest.TestCase): + + def test_decorate_timer(self): + + @TimeCounter() + def demo_fun(): + time.sleep(0.1) + + demo_fun() + + @TimeCounter() + def demo_fun(): + time.sleep(0.1) + + for _ in range(10): + demo_fun() + + @TimeCounter(log_interval=2, with_sync=False, tag='demo_fun1') + def demo_fun(): + time.sleep(0.1) + + demo_fun() + + # warmup_interval must be greater than 0 + with self.assertRaises(AssertionError): + + @TimeCounter(warmup_interval=0) + def demo_fun(): + time.sleep(0.1) + + def test_context_timer(self): + + # tag must be specified in context mode + with self.assertRaises(AssertionError): + with TimeCounter(): + time.sleep(0.1) + + # warmup_interval must be greater than 0 + with self.assertRaises(AssertionError): + with TimeCounter(warmup_interval=0, tag='func_1'): + time.sleep(0.1) + + with TimeCounter(tag='func_1'): + time.sleep(0.1) + + for _ in range(10): + with TimeCounter(log_interval=2, with_sync=False, tag='func_2'): + time.sleep(0.1)