From 97f9efd825c43f5a1853241d3a2f43645f21fef5 Mon Sep 17 00:00:00 2001
From: Rui Xu
Date: Wed, 17 Jun 2020 14:01:11 +0800
Subject: [PATCH] fix bug in using inner iter when by_epcoh==True (#346)

* fix bug in using inner iter when by_epcoh==True

* set default for log config

* fix bug: remove optim in val, use iter in log when by_epoch is False
---
 mmcv/runner/hooks/logger/base.py        | 12 ++++++++++--
 mmcv/runner/hooks/logger/pavi.py        |  6 ++++--
 mmcv/runner/hooks/logger/tensorboard.py |  5 +++--
 mmcv/runner/hooks/logger/text.py        | 10 +++++++---
 mmcv/runner/iter_based_runner.py        |  6 ++++--
 5 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/mmcv/runner/hooks/logger/base.py b/mmcv/runner/hooks/logger/base.py
index a518611de..e6665252f 100644
--- a/mmcv/runner/hooks/logger/base.py
+++ b/mmcv/runner/hooks/logger/base.py
@@ -12,14 +12,20 @@ class LoggerHook(Hook):
         ignore_last (bool): Ignore the log of last iterations in each epoch
             if less than `interval`.
         reset_flag (bool): Whether to clear the output buffer after logging.
+        by_epoch (bool): Whether EpochBasedRunner is used.
     """
 
     __metaclass__ = ABCMeta
 
-    def __init__(self, interval=10, ignore_last=True, reset_flag=False):
+    def __init__(self,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
         self.interval = interval
         self.ignore_last = ignore_last
         self.reset_flag = reset_flag
+        self.by_epoch = by_epoch
 
     @abstractmethod
     def log(self, runner):
@@ -35,7 +41,9 @@ class LoggerHook(Hook):
             runner.log_buffer.clear()  # clear logs of last epoch
 
     def after_train_iter(self, runner):
-        if self.every_n_inner_iters(runner, self.interval):
+        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+            runner.log_buffer.average(self.interval)
+        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
             runner.log_buffer.average(self.interval)
         elif self.end_of_epoch(runner) and not self.ignore_last:
             # not precise but more stable
diff --git a/mmcv/runner/hooks/logger/pavi.py b/mmcv/runner/hooks/logger/pavi.py
index 09e201dcf..9eb7d2169 100644
--- a/mmcv/runner/hooks/logger/pavi.py
+++ b/mmcv/runner/hooks/logger/pavi.py
@@ -40,8 +40,10 @@ class PaviLoggerHook(LoggerHook):
                  add_last_ckpt=False,
                  interval=10,
                  ignore_last=True,
-                 reset_flag=True):
-        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+                 reset_flag=True,
+                 by_epoch=True):
+        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
         self.init_kwargs = init_kwargs
         self.add_graph = add_graph
         self.add_last_ckpt = add_last_ckpt
diff --git a/mmcv/runner/hooks/logger/tensorboard.py b/mmcv/runner/hooks/logger/tensorboard.py
index fd9ac198e..27416107d 100644
--- a/mmcv/runner/hooks/logger/tensorboard.py
+++ b/mmcv/runner/hooks/logger/tensorboard.py
@@ -14,9 +14,10 @@ class TensorboardLoggerHook(LoggerHook):
                  log_dir=None,
                  interval=10,
                  ignore_last=True,
-                 reset_flag=True):
+                 reset_flag=True,
+                 by_epoch=True):
         super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
-                                                    reset_flag)
+                                                    reset_flag, by_epoch)
         self.log_dir = log_dir
 
     @master_only
diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py
index b8c767033..866016470 100644
--- a/mmcv/runner/hooks/logger/text.py
+++ b/mmcv/runner/hooks/logger/text.py
@@ -35,7 +35,8 @@ class TextLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  interval_exp_name=1000):
-        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
         self.by_epoch = by_epoch
         self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
@@ -61,7 +62,7 @@ class TextLoggerHook(LoggerHook):
         # print exp name for users to distinguish experiments
         # at every ``interval_exp_name`` iterations and the end of each epoch
         if runner.meta is not None and 'exp_name' in runner.meta:
-            if (self.every_n_inner_iters(runner, self.interval_exp_name)) or (
+            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                     self.by_epoch and self.end_of_epoch(runner)):
                 exp_info = f'Exp name: {runner.meta["exp_name"]}'
                 runner.logger.info(exp_info)
@@ -144,7 +145,10 @@ class TextLoggerHook(LoggerHook):
         mode = 'train' if 'time' in runner.log_buffer.output else 'val'
         log_dict['mode'] = mode
         log_dict['epoch'] = runner.epoch + 1
-        log_dict['iter'] = runner.inner_iter + 1
+        if self.by_epoch:
+            log_dict['iter'] = runner.inner_iter + 1
+        else:
+            log_dict['iter'] = runner.iter + 1
         # only record lr of the first param group
         cur_lr = runner.current_lr()
         if isinstance(cur_lr, list):
diff --git a/mmcv/runner/iter_based_runner.py b/mmcv/runner/iter_based_runner.py
index 38018c62f..06922e7fc 100644
--- a/mmcv/runner/iter_based_runner.py
+++ b/mmcv/runner/iter_based_runner.py
@@ -65,11 +65,10 @@ class IterBasedRunner(BaseRunner):
     def val(self, data_loader, **kwargs):
         self.model.eval()
         self.mode = 'val'
-        self._inner_iter = 0
         self.data_loader = data_loader
         self.call_hook('before_val_iter')
         data_batch = next(data_loader)
-        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+        outputs = self.model.val_step(data_batch, **kwargs)
         if not isinstance(outputs, dict):
             raise TypeError('model.val_step() must return a dict')
         if 'log_vars' in outputs:
@@ -107,6 +106,7 @@ class IterBasedRunner(BaseRunner):
 
         while self.iter < max_iters:
             for i, flow in enumerate(workflow):
+                self._inner_iter = 0
                 mode, iters = flow
                 if not isinstance(mode, str) or not hasattr(self, mode):
                     raise ValueError(
@@ -220,4 +220,6 @@ class IterBasedRunner(BaseRunner):
         self.register_optimizer_hook(optimizer_config)
         self.register_checkpoint_hook(checkpoint_config)
         self.register_hook(IterTimerHook())
+        if log_config is not None:
+            log_config.setdefault('by_epoch', False)
         self.register_logger_hooks(log_config)
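Editor's note (not part of the patch): a minimal usage sketch of the ``by_epoch`` flag
this patch threads through the logger hooks, assuming the patched TextLoggerHook
signature shown above and the ``mmcv.runner.hooks`` import path. With an
EpochBasedRunner the default ``by_epoch=True`` keeps the logged ``iter`` as the
within-epoch counter; with an IterBasedRunner, ``register_training_hooks`` now
defaults the log config to ``by_epoch=False`` so ``iter`` reports the global
iteration instead.

    from mmcv.runner.hooks import TextLoggerHook

    # Epoch-based training (default): the logged 'iter' is
    # runner.inner_iter + 1, i.e. the iteration within the current epoch,
    # and logging fires every `interval` inner iterations.
    epoch_logger = TextLoggerHook(interval=50, by_epoch=True)

    # Iteration-based training: the logged 'iter' is runner.iter + 1, the
    # global iteration count; IterBasedRunner.register_training_hooks sets
    # this default via log_config.setdefault('by_epoch', False).
    iter_logger = TextLoggerHook(interval=50, by_epoch=False)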