mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* [Fix] Fix lint * [Fix] Fix lint * Update mmengine/dist/utils.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
141 lines
4.6 KiB
Python
141 lines
4.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import functools
import time
from typing import Optional, Union

import torch

from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log
|
|
|
|
|
|
class TimeCounter:
    """A tool that counts the average running time of a function or a method.

    Users can use it as a decorator or context manager to calculate the average
    running time of code blocks.

    Instances are cached per ``tag``: constructing a ``TimeCounter`` with a
    tag that has been used before returns the already existing instance, so
    its settings and accumulated statistics are kept and the newly passed
    arguments are ignored.

    Args:
        log_interval (int): The interval of logging. Defaults to 1.
        warmup_interval (int): The interval of warmup. The first
            ``warmup_interval - 1`` runs are excluded from the average.
            Must be at least 1. Defaults to 1.
        with_sync (bool): Whether to synchronize the device (CUDA or MUSA)
            before reading the clock. Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.dl_utils import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms

        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms

        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    # Registry of created instances, keyed by tag.
    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    with_sync: bool
    tag: Optional[str]
    logger: Optional[MMLogger]
    # Number of timed runs started so far (warm-up runs included).
    __count: int
    # Accumulated elapsed seconds of the runs that count towards the average.
    __pure_inf_time: float
    # Start timestamp of the currently open ``with`` block.
    __start_time: float

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        """Create a counter, or return the existing one for a known ``tag``.

        All state is initialised here instead of in ``__init__`` so that
        fetching a cached instance does not reset its counters.
        """
        assert warmup_interval >= 1
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        # Untagged instances all land under the ``None`` key (each new one
        # replaces the previous); only tagged entries are ever looked up.
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger

        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.

        return instance

    def _synchronize(self) -> None:
        """Synchronize the CUDA or MUSA device when ``with_sync`` is set.

        Shared by the decorator and the context-manager paths so that both
        treat the available devices the same way.
        """
        if not self.with_sync:
            return
        if is_cuda_available():
            torch.cuda.synchronize()
        elif is_musa_available():
            torch.musa.synchronize()

    # NOTE(review): ``master_only`` is defined elsewhere. If it returns None
    # on non-master ranks, decorating a function there would replace the
    # function with None - confirm against its implementation.
    @master_only
    def __call__(self, fn):
        """Use the counter as a decorator: wrap ``fn`` so each call is timed.

        Args:
            fn (Callable): The function to time. Its ``__name__`` becomes the
                tag when no tag was given.

        Returns:
            Callable: A wrapper that calls ``fn``, reports the elapsed time
            through :meth:`print_time` and returns ``fn``'s result.
        """
        if self.tag is None:
            self.tag = fn.__name__

        # Keep the name/docstring of ``fn`` on the wrapper.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            self.__count += 1

            self._synchronize()
            start_time = time.perf_counter()

            result = fn(*args, **kwargs)

            self._synchronize()
            elapsed = time.perf_counter() - start_time
            self.print_time(elapsed)

            return result

        return wrapper

    @master_only
    def __enter__(self):
        """Start timing a ``with`` block; requires ``tag`` to be set."""
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'

        self.__count += 1

        self._synchronize()
        self.__start_time = time.perf_counter()

    @master_only
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing the ``with`` block and report the elapsed time."""
        self._synchronize()
        elapsed = time.perf_counter() - self.__start_time
        self.print_time(elapsed)

    def print_time(self, elapsed: Union[int, float]) -> None:
        """Accumulate ``elapsed`` and log the running average when due.

        Runs before the ``warmup_interval``-th one are ignored. After that,
        the average (in milliseconds) is logged every ``log_interval`` runs.

        Args:
            elapsed (int | float): Duration of the latest run in seconds.
        """
        if self.__count >= self.warmup_interval:
            self.__pure_inf_time += elapsed

            if self.__count % self.log_interval == 0:
                # Number of runs accumulated so far; always >= 1 here because
                # of the warm-up check above.
                times_per_count = 1000 * self.__pure_inf_time / (
                    self.__count - self.warmup_interval + 1)
                # The message reports the total run count, warm-up included.
                print_log(
                    f'[{self.tag}]-time per run averaged in the past '
                    f'{self.__count} runs: {times_per_count:.1f} ms',
                    self.logger)
|