[Fix] Fix error format of log message (#508)
* Fix error format of log message * Fix unit test * remove unnecessary commentpull/634/head
parent
abe56651db
commit
aaba1d8871
|
@ -19,6 +19,7 @@ class IterTimerHook(Hook):
|
|||
|
||||
def __init__(self):
|
||||
self.time_sec_tot = 0
|
||||
self.time_sec_test_val = 0
|
||||
self.start_iter = 0
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
|
@ -41,6 +42,9 @@ class IterTimerHook(Hook):
|
|||
"""
|
||||
self.t = time.time()
|
||||
|
||||
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
||||
self.time_sec_test_val = 0
|
||||
|
||||
def _before_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
|
@ -82,26 +86,22 @@ class IterTimerHook(Hook):
|
|||
message_hub = runner.message_hub
|
||||
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
|
||||
self.t = time.time()
|
||||
window_size = runner.log_processor.window_size
|
||||
# Calculate eta every `window_size` iterations. Since test and val
|
||||
# loop will not update runner.iter, use `every_n_innter_iters`to check
|
||||
# the interval.
|
||||
if self.every_n_inner_iters(batch_idx, window_size):
|
||||
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
|
||||
window_size)
|
||||
if mode == 'train':
|
||||
self.time_sec_tot += iter_time * window_size
|
||||
# Calculate average iterative time.
|
||||
time_sec_avg = self.time_sec_tot / (
|
||||
runner.iter - self.start_iter + 1)
|
||||
# Calculate eta.
|
||||
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
||||
runner.message_hub.update_info('eta', eta_sec)
|
||||
iter_time = message_hub.get_scalar(f'{mode}/time')
|
||||
if mode == 'train':
|
||||
self.time_sec_tot += iter_time.current()
|
||||
# Calculate average iterative time.
|
||||
time_sec_avg = self.time_sec_tot / (
|
||||
runner.iter - self.start_iter + 1)
|
||||
# Calculate eta.
|
||||
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
||||
runner.message_hub.update_info('eta', eta_sec)
|
||||
else:
|
||||
if mode == 'val':
|
||||
cur_dataloader = runner.val_dataloader
|
||||
else:
|
||||
if mode == 'val':
|
||||
cur_dataloader = runner.val_dataloader
|
||||
else:
|
||||
cur_dataloader = runner.test_dataloader
|
||||
cur_dataloader = runner.test_dataloader
|
||||
|
||||
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
|
||||
runner.message_hub.update_info('eta', eta_sec)
|
||||
self.time_sec_test_val += iter_time.current()
|
||||
time_sec_avg = self.time_sec_test_val / (batch_idx + 1)
|
||||
eta_sec = time_sec_avg * (len(cur_dataloader) - batch_idx - 1)
|
||||
runner.message_hub.update_info('eta', eta_sec)
|
||||
|
|
|
@ -53,18 +53,30 @@ class TestIterTimerHook(TestCase):
|
|||
runner.iter = 0
|
||||
runner.test_dataloader = [0] * 20
|
||||
runner.val_dataloader = [0] * 20
|
||||
self.hook._before_epoch(runner)
|
||||
self.hook.before_run(runner)
|
||||
self.hook._after_iter(runner, batch_idx=1)
|
||||
runner.message_hub.update_scalar.assert_called()
|
||||
runner.message_hub.get_log.assert_not_called()
|
||||
runner.message_hub.update_info.assert_not_called()
|
||||
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
|
||||
runner.iter = 9
|
||||
|
||||
self.hook.before_run(runner)
|
||||
self.hook._before_epoch(runner)
|
||||
# eta = (100 - 10) / 1
|
||||
self.hook._after_iter(runner, batch_idx=89)
|
||||
for _ in range(10):
|
||||
self.hook._after_iter(runner, 1)
|
||||
runner.iter += 1
|
||||
assert runner.message_hub.get_info('eta') == 90
|
||||
self.hook._after_iter(runner, batch_idx=9, mode='val')
|
||||
|
||||
for i in range(10):
|
||||
self.hook._after_iter(runner, batch_idx=i, mode='val')
|
||||
assert runner.message_hub.get_info('eta') == 10
|
||||
self.hook._after_iter(runner, batch_idx=19, mode='test')
|
||||
|
||||
for i in range(11, 20):
|
||||
self.hook._after_iter(runner, batch_idx=i, mode='val')
|
||||
assert runner.message_hub.get_info('eta') == 0
|
||||
|
||||
self.hook.after_val_epoch(runner)
|
||||
|
||||
for i in range(10):
|
||||
self.hook._after_iter(runner, batch_idx=i, mode='test')
|
||||
assert runner.message_hub.get_info('eta') == 10
|
||||
|
||||
for i in range(11, 20):
|
||||
self.hook._after_iter(runner, batch_idx=i, mode='test')
|
||||
assert runner.message_hub.get_info('eta') == 0
|
||||
|
|
Loading…
Reference in New Issue