EasyCV/easycv/hooks/throughput_hook.py

59 lines
2.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import time
from mmcv.runner.hooks import Hook
from torch import distributed as dist
from easycv.hooks.registry import HOOKS
from easycv.utils.dist_utils import get_dist_info
@HOOKS.register_module()
class ThroughputHook(Hook):
    """Log the average training throughput in samples per second.

    After every training iteration the hook writes the key
    ``'avg throughput'`` into ``runner.log_buffer``. The value is the number
    of samples processed since timing started (per-process batch size
    multiplied by the world size) divided by the elapsed wall-clock time.

    `warmup_iters` can be set to skip the first few steps of each epoch,
    if the initialization of the first few steps is slow. While warming up,
    the reported throughput is 0.

    Args:
        warmup_iters (int): Number of iterations to exclude from the
            measurement. Defaults to 0.
        **kwargs: Accepted for config compatibility; not used by this hook.
    """

    def __init__(self, warmup_iters=0, **kwargs) -> None:
        self.warmup_iters = warmup_iters
        self._iter_count = 0
        self._start = False
        # Always define the start time so that `after_train_iter` never hits
        # an unset attribute, even if `before_train_epoch` was not fired
        # before the first iteration and warmup is still in progress.
        self._start_time = time.time()

    def _reset(self):
        """Restart the timer and clear the iteration/warmup state."""
        self._start_time = time.time()
        self._iter_count = 0
        self._start = False

    def before_train_epoch(self, runner):
        """Reset the statistics at the beginning of every epoch."""
        self._reset()

    def before_train_iter(self, runner):
        """Start the timer once exactly `warmup_iters` steps have finished."""
        if not self._start and self._iter_count == self.warmup_iters:
            self._start_time = time.time()
            self._start = True

    def after_train_iter(self, runner):
        """Compute the running average throughput and push it to the log.

        Args:
            runner: The runner driving training. This method reads
                ``runner.data_loader.batch_size`` and updates
                ``runner.log_buffer``.
        """
        self._iter_count += 1
        key = 'avg throughput'

        batch_size = runner.data_loader.batch_size
        _, world_size = get_dist_info()
        total_batch_size = batch_size * world_size

        # The LoggerHook will average the log_buffer of the latest interval,
        # but we want to use the total time to calculate the throughput,
        # so we delete the historical buffers of the key to ensure that
        # the value printed each time is the total historical average.
        if key in runner.log_buffer.val_history:
            runner.log_buffer.val_history[key] = []
            runner.log_buffer.n_history[key] = []

        # Iterations that count towards the measurement (0 during warmup).
        measured_iters = max(0, self._iter_count - self.warmup_iters)
        total_time = time.time() - self._start_time

        # A coarse or adjusted wall clock can yield a zero or negative
        # interval; report 0 instead of raising ZeroDivisionError or logging
        # a negative throughput.
        if total_time > 0:
            throughput = measured_iters * total_batch_size / total_time
        else:
            throughput = 0.0

        runner.log_buffer.update({key: throughput}, count=self._iter_count)