[Fix] Fix error format of log message (#508)

* Fix error format of log message

* Fix unit test

* remove unnecessary comment
pull/634/head
Mashiro 2022-10-18 18:04:15 +08:00 committed by GitHub
parent abe56651db
commit aaba1d8871
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 31 deletions

View File

@ -19,6 +19,7 @@ class IterTimerHook(Hook):
def __init__(self):
self.time_sec_tot = 0
self.time_sec_test_val = 0
self.start_iter = 0
def before_train(self, runner) -> None:
@ -41,6 +42,9 @@ class IterTimerHook(Hook):
"""
self.t = time.time()
def _after_epoch(self, runner, mode: str = 'train') -> None:
self.time_sec_test_val = 0
def _before_iter(self,
runner,
batch_idx: int,
@ -82,26 +86,22 @@ class IterTimerHook(Hook):
message_hub = runner.message_hub
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
self.t = time.time()
window_size = runner.log_processor.window_size
# Calculate eta every `window_size` iterations. Since test and val
# loop will not update runner.iter, use `every_n_innter_iters`to check
# the interval.
if self.every_n_inner_iters(batch_idx, window_size):
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
window_size)
if mode == 'train':
self.time_sec_tot += iter_time * window_size
# Calculate average iterative time.
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
iter_time = message_hub.get_scalar(f'{mode}/time')
if mode == 'train':
self.time_sec_tot += iter_time.current()
# Calculate average iterative time.
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
else:
if mode == 'val':
cur_dataloader = runner.val_dataloader
else:
if mode == 'val':
cur_dataloader = runner.val_dataloader
else:
cur_dataloader = runner.test_dataloader
cur_dataloader = runner.test_dataloader
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)
self.time_sec_test_val += iter_time.current()
time_sec_avg = self.time_sec_test_val / (batch_idx + 1)
eta_sec = time_sec_avg * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)

View File

@ -53,18 +53,30 @@ class TestIterTimerHook(TestCase):
runner.iter = 0
runner.test_dataloader = [0] * 20
runner.val_dataloader = [0] * 20
self.hook._before_epoch(runner)
self.hook.before_run(runner)
self.hook._after_iter(runner, batch_idx=1)
runner.message_hub.update_scalar.assert_called()
runner.message_hub.get_log.assert_not_called()
runner.message_hub.update_info.assert_not_called()
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
runner.iter = 9
self.hook.before_run(runner)
self.hook._before_epoch(runner)
# eta = (100 - 10) / 1
self.hook._after_iter(runner, batch_idx=89)
for _ in range(10):
self.hook._after_iter(runner, 1)
runner.iter += 1
assert runner.message_hub.get_info('eta') == 90
self.hook._after_iter(runner, batch_idx=9, mode='val')
for i in range(10):
self.hook._after_iter(runner, batch_idx=i, mode='val')
assert runner.message_hub.get_info('eta') == 10
self.hook._after_iter(runner, batch_idx=19, mode='test')
for i in range(11, 20):
self.hook._after_iter(runner, batch_idx=i, mode='val')
assert runner.message_hub.get_info('eta') == 0
self.hook.after_val_epoch(runner)
for i in range(10):
self.hook._after_iter(runner, batch_idx=i, mode='test')
assert runner.message_hub.get_info('eta') == 10
for i in range(11, 20):
self.hook._after_iter(runner, batch_idx=i, mode='test')
assert runner.message_hub.get_info('eta') == 0