import sys
import time
import warnings
from contextlib import contextmanager
import torch


class TimeCounter:
"""A tool for counting inference time of backends."""
names = dict()
file = sys.stdout
# Avoid instantiating every time
@classmethod
def count_time(cls, warmup=1, log_interval=1, with_sync=False):
def _register(func):
assert warmup >= 1
assert func.__name__ not in cls.names,\
'The registered function name cannot be repeated!'
# When adding on multiple functions, we need to ensure that the
# data does not interfere with each other
cls.names[func.__name__] = dict(
count=0,
execute_time=0,
log_interval=log_interval,
warmup=warmup,
with_sync=with_sync,
enable=False)
def fun(*args, **kwargs):
count = cls.names[func.__name__]['count']
execute_time = cls.names[func.__name__]['execute_time']
log_interval = cls.names[func.__name__]['log_interval']
warmup = cls.names[func.__name__]['warmup']
with_sync = cls.names[func.__name__]['with_sync']
enable = cls.names[func.__name__]['enable']
count += 1
cls.names[func.__name__]['count'] = count
if enable:
if with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.perf_counter()
result = func(*args, **kwargs)
if enable:
if with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
if enable and count > warmup:
execute_time += elapsed
cls.names[func.__name__]['execute_time'] = execute_time
if (count - warmup) % log_interval == 0:
times_per_count = 1000 * execute_time / (
count - warmup)
msg = f'[{func.__name__}]-{count} times per count: '\
f'{times_per_count:.1f} ms, '\
f'{1000/times_per_count:.2f} FPS'
if cls.file != sys.stdout:
msg += '\n'
cls.file.write(msg)
cls.file.flush()
return result
return fun
return _register
    @classmethod
    @contextmanager
    def activate(cls,
                 func_name: str = None,
                 warmup: int = 1,
                 log_interval: int = 1,
                 with_sync: bool = False,
                 file=sys.stdout):
        """Activate the time counter.

        Args:
            func_name (str): Which function to activate. If not specified,
                all registered functions will be activated.
            warmup (int): The warm up steps, default 1.
            log_interval (int): Interval between each log, default 1.
            with_sync (bool): Whether to use cuda synchronize for time
                counting, default False.
            file: A file path to write the logs to, default sys.stdout.
        """
        assert warmup >= 1
        if file != sys.stdout:
            file = open(file, 'w+')
        cls.file = file

        if func_name is not None:
            warnings.warn('func_name must be globally unique if you call '
                          'activate multiple times')
            assert func_name in cls.names, '{} must be registered before '\
                'setting params'.format(func_name)
            cls.names[func_name]['warmup'] = warmup
            cls.names[func_name]['log_interval'] = log_interval
            cls.names[func_name]['with_sync'] = with_sync
            cls.names[func_name]['enable'] = True
        else:
            # Activate every registered function with the same settings.
            for name in cls.names:
                cls.names[name]['warmup'] = warmup
                cls.names[name]['log_interval'] = log_interval
                cls.names[name]['with_sync'] = with_sync
                cls.names[name]['enable'] = True

        yield

        # Close the log file (if any) and disable the timers on exit.
        if file != sys.stdout:
            cls.file.close()
        if func_name is not None:
            cls.names[func_name]['enable'] = False
        else:
            for name in cls.names:
                cls.names[name]['enable'] = False
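
# A minimal, self-contained usage sketch: ``run_model`` and the toy linear
# layer below are purely illustrative, but they show how ``count_time``
# registers a function and how ``activate`` switches timing on only inside
# the context manager.
if __name__ == '__main__':
    model = torch.nn.Linear(32, 32)

    @TimeCounter.count_time(warmup=2, log_interval=5)
    def run_model(x):
        return model(x)

    # Timing is reported only while the counter is activated; the first two
    # (warmup) calls are excluded from the averaged statistics.
    with TimeCounter.activate('run_model', warmup=2, log_interval=5):
        for _ in range(20):
            run_model(torch.randn(1, 32))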