mirror of https://github.com/open-mmlab/mmcv.git
Fix the time estimation when resuming from a checkpoint (#37)
* fix the time estimation when resuming from a checkpoint * fix the time estimation when resuming from a checkpointpull/38/head
parent
1d6e91b1b0
commit
8b4417c19a
|
@ -9,6 +9,10 @@ class TextLoggerHook(LoggerHook):
|
|||
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
|
||||
self.time_sec_tot = 0
|
||||
|
||||
def before_run(self, runner):
|
||||
super(TextLoggerHook, self).before_run(runner)
|
||||
self.start_iter = runner.iter
|
||||
|
||||
def log(self, runner):
|
||||
if runner.mode == 'train':
|
||||
lr_str = ', '.join(
|
||||
|
@ -20,9 +24,10 @@ class TextLoggerHook(LoggerHook):
|
|||
log_str = 'Epoch({}) [{}][{}]\t'.format(runner.mode, runner.epoch,
|
||||
runner.inner_iter + 1)
|
||||
if 'time' in runner.log_buffer.output:
|
||||
self.time_sec_tot += (runner.log_buffer.output['time'] *
|
||||
self.interval)
|
||||
time_sec_avg = self.time_sec_tot / (runner.iter + 1)
|
||||
self.time_sec_tot += (
|
||||
runner.log_buffer.output['time'] * self.interval)
|
||||
time_sec_avg = self.time_sec_tot / (
|
||||
runner.iter - self.start_iter + 1)
|
||||
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
||||
log_str += 'eta: {}, '.format(eta_str)
|
||||
|
|
Loading…
Reference in New Issue