From 3ea98248ba0d3528aecb7397522ab32b6bbe0ed7 Mon Sep 17 00:00:00 2001 From: thangvu Date: Mon, 10 Dec 2018 17:09:19 +0900 Subject: [PATCH] add moving average eta --- mmcv/runner/hooks/logger/text.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py index 7693bbb4e..8b50e8ef3 100644 --- a/mmcv/runner/hooks/logger/text.py +++ b/mmcv/runner/hooks/logger/text.py @@ -1,9 +1,14 @@ -from .base import LoggerHook import datetime +from .base import LoggerHook + class TextLoggerHook(LoggerHook): + def __init__(self, interval=10, ignore_last=True, reset_flag=False): + super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag) + self.time_sec_tot = 0 + def log(self, runner): if runner.mode == 'train': lr_str = ', '.join( @@ -15,10 +20,10 @@ class TextLoggerHook(LoggerHook): log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch, runner.inner_iter + 1) if 'time' in runner.log_buffer.output: - tot_time_sec = runner.log_buffer.output['time'] + \ - runner.log_buffer.output['data_time'] - eta_sec = tot_time_sec * \ - (len(runner.data_loader) * runner.max_epochs - runner._iter) + self.time_sec_tot += (runner.log_buffer.output['time'] * + self.interval) + time_sec_avg = self.time_sec_tot / (runner.iter + 1) + eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) eta_str = str(datetime.timedelta(seconds=int(eta_sec))) log_str += ('eta: {}, '.format(eta_str)) log_str += (