fix bug in using inner iter when by_epoch==True (#346)

* fix bug in using inner iter when by_epoch==True

* set default for log config

* fix bug: remove optim in val, use iter in log when by_epoch is False
pull/352/head
Rui Xu 2020-06-17 14:01:11 +08:00 committed by GitHub
parent 630b747cb1
commit 97f9efd825
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 11 deletions

View File

@ -12,14 +12,20 @@ class LoggerHook(Hook):
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
reset_flag (bool): Whether to clear the output buffer after logging.
by_epoch (bool): Whether EpochBasedRunner is used.
"""
__metaclass__ = ABCMeta
def __init__(self, interval=10, ignore_last=True, reset_flag=False):
def __init__(self,
interval=10,
ignore_last=True,
reset_flag=False,
by_epoch=True):
self.interval = interval
self.ignore_last = ignore_last
self.reset_flag = reset_flag
self.by_epoch = by_epoch
@abstractmethod
def log(self, runner):
@ -35,7 +41,9 @@ class LoggerHook(Hook):
runner.log_buffer.clear() # clear logs of last epoch
def after_train_iter(self, runner):
if self.every_n_inner_iters(runner, self.interval):
if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
runner.log_buffer.average(self.interval)
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
runner.log_buffer.average(self.interval)
elif self.end_of_epoch(runner) and not self.ignore_last:
# not precise but more stable

View File

@ -40,8 +40,10 @@ class PaviLoggerHook(LoggerHook):
add_last_ckpt=False,
interval=10,
ignore_last=True,
reset_flag=True):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
reset_flag=True,
by_epoch=True):
super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
self.init_kwargs = init_kwargs
self.add_graph = add_graph
self.add_last_ckpt = add_last_ckpt

View File

@ -14,9 +14,10 @@ class TensorboardLoggerHook(LoggerHook):
log_dir=None,
interval=10,
ignore_last=True,
reset_flag=True):
reset_flag=True,
by_epoch=True):
super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
reset_flag)
reset_flag, by_epoch)
self.log_dir = log_dir
@master_only

View File

@ -35,7 +35,8 @@ class TextLoggerHook(LoggerHook):
ignore_last=True,
reset_flag=False,
interval_exp_name=1000):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
@ -61,7 +62,7 @@ class TextLoggerHook(LoggerHook):
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_inner_iters(runner, self.interval_exp_name)) or (
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)
@ -144,7 +145,10 @@ class TextLoggerHook(LoggerHook):
mode = 'train' if 'time' in runner.log_buffer.output else 'val'
log_dict['mode'] = mode
log_dict['epoch'] = runner.epoch + 1
log_dict['iter'] = runner.inner_iter + 1
if self.by_epoch:
log_dict['iter'] = runner.inner_iter + 1
else:
log_dict['iter'] = runner.iter + 1
# only record lr of the first param group
cur_lr = runner.current_lr()
if isinstance(cur_lr, list):

View File

@ -65,11 +65,10 @@ class IterBasedRunner(BaseRunner):
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self._inner_iter = 0
self.data_loader = data_loader
self.call_hook('before_val_iter')
data_batch = next(data_loader)
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
outputs = self.model.val_step(data_batch, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('model.val_step() must return a dict')
if 'log_vars' in outputs:
@ -107,6 +106,7 @@ class IterBasedRunner(BaseRunner):
while self.iter < max_iters:
for i, flow in enumerate(workflow):
self._inner_iter = 0
mode, iters = flow
if not isinstance(mode, str) or not hasattr(self, mode):
raise ValueError(
@ -220,4 +220,6 @@ class IterBasedRunner(BaseRunner):
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
if log_config is not None:
log_config.setdefault('by_epoch', False)
self.register_logger_hooks(log_config)