add moving average eta

pull/26/head
thangvu 2018-12-10 17:09:19 +09:00
parent 578bbe6241
commit 3ea98248ba
1 changed files with 10 additions and 5 deletions

View File

@ -1,9 +1,14 @@
from .base import LoggerHook
import datetime import datetime
from .base import LoggerHook
class TextLoggerHook(LoggerHook): class TextLoggerHook(LoggerHook):
def __init__(self, interval=10, ignore_last=True, reset_flag=False):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
self.time_sec_tot = 0
def log(self, runner): def log(self, runner):
if runner.mode == 'train': if runner.mode == 'train':
lr_str = ', '.join( lr_str = ', '.join(
@ -15,10 +20,10 @@ class TextLoggerHook(LoggerHook):
log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch, log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
runner.inner_iter + 1) runner.inner_iter + 1)
if 'time' in runner.log_buffer.output: if 'time' in runner.log_buffer.output:
tot_time_sec = runner.log_buffer.output['time'] + \ self.time_sec_tot += (runner.log_buffer.output['time'] *
runner.log_buffer.output['data_time'] self.interval)
eta_sec = tot_time_sec * \ time_sec_avg = self.time_sec_tot / (runner.iter + 1)
(len(runner.data_loader) * runner.max_epochs - runner._iter) eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec))) eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += ('eta: {}, '.format(eta_str)) log_str += ('eta: {}, '.format(eta_str))
log_str += ( log_str += (