mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import sys
import time

import torch

if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU.

        Useful as a temporary context manager to find sweet spots of
        code suitable for async implementation.
        """
        if (not enabled) or not torch.cuda.is_available():
            yield
            return
        stream = stream if stream else torch.cuda.current_stream()
        end_stream = end_stream if end_stream else stream
        # CUDA events measure GPU time on the chosen stream(s); CPU time is
        # taken with a monotonic clock around the wrapped block.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        try:
            cpu_start = time.monotonic()
            yield
        finally:
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            end.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000
            gpu_time = start.elapsed_time(end)
            msg = '{} {} cpu_time {:.2f} ms '.format(trace_name, name,
                                                     cpu_time)
            msg += 'gpu_time {:.2f} ms stream {}'.format(gpu_time, stream)
            print(msg, end_stream)


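# Illustrative usage sketch (not part of the original file): wrap any region of
# code with profile_time to print its CPU and GPU time per pass. `model` and
# `images` below are assumed placeholders, not symbols defined in this module.
#
#     with profile_time('detector', 'backbone_forward'):
#         feats = model.backbone(images)
#
# Because timing relies on CUDA events, the numbers are only meaningful when a
# GPU is available; otherwise the context manager is a no-op.
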
def benchmark_torch_function(iters, f, *args):
    # Warm-up call, then synchronize so pending GPU work does not pollute the
    # measurement.
    f(*args)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for i in range(iters):
        f(*args)
    end_event.record()
    torch.cuda.synchronize()
    # elapsed_time is in milliseconds; return average seconds per iteration.
    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters


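# Illustrative usage sketch (not part of the original file): time an average
# forward pass. `model` and `inputs` are assumed placeholders; the function
# returns seconds per iteration, so multiply by 1e3 for milliseconds.
#
#     sec_per_iter = benchmark_torch_function(100, model, inputs)
#     print('avg forward: {:.3f} ms'.format(sec_per_iter * 1e3))
#
# The first call is a warm-up and is excluded from the measurement; since
# torch.cuda.synchronize() is called unconditionally, this helper assumes a GPU.
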
def time_synchronized():
    # Wait for all queued CUDA work before reading the wall clock, so timings
    # reflect completed GPU computation.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
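

# Illustrative usage sketch (not part of the original file): bracket a region
# with time_synchronized to get wall-clock timings that are not skewed by
# asynchronous CUDA execution. `model` and `inputs` are assumed placeholders.
#
#     t0 = time_synchronized()
#     out = model(inputs)
#     t1 = time_synchronized()
#     print('inference took {:.1f} ms'.format((t1 - t0) * 1e3))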