mirror of https://github.com/open-mmlab/mmcv.git

fix bug in using inner iter when by_epoch==True (#346)

* fix bug in using inner iter when by_epoch==True
* set default for log config
* fix bug: remove optim in val, use iter in log when by_epoch is False

parent 630b747cb1
commit 97f9efd825
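The fix hinges on the difference between the runner's global iteration counter and its per-epoch inner counter. The stand-alone sketch below (not mmcv code; FakeRunner and the numbers are hypothetical) mirrors the two Hook predicates the diff switches between:

class FakeRunner:
    """Hypothetical stand-in for a runner: ``iter`` grows over the whole
    run, while ``inner_iter`` restarts at 0 each epoch/workflow stage."""

    def __init__(self, it, inner_it):
        self.iter = it              # global iteration counter
        self.inner_iter = inner_it  # per-epoch iteration counter


def every_n_iters(runner, n):
    # mirrors Hook.every_n_iters: fires on the global counter
    return (runner.iter + 1) % n == 0 if n > 0 else False


def every_n_inner_iters(runner, n):
    # mirrors Hook.every_n_inner_iters: fires on the per-epoch counter
    return (runner.inner_iter + 1) % n == 0 if n > 0 else False


# 100th iteration overall, but only the 5th since the stage restarted
runner = FakeRunner(it=99, inner_it=4)
print(every_n_iters(runner, 10))        # True: time to log
print(every_n_inner_iters(runner, 10))  # False: the old check misses it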
@@ -12,14 +12,20 @@ class LoggerHook(Hook):
         ignore_last (bool): Ignore the log of last iterations in each epoch
             if less than `interval`.
         reset_flag (bool): Whether to clear the output buffer after logging.
+        by_epoch (bool): Whether EpochBasedRunner is used.
     """

    __metaclass__ = ABCMeta

-    def __init__(self, interval=10, ignore_last=True, reset_flag=False):
+    def __init__(self,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
         self.interval = interval
         self.ignore_last = ignore_last
         self.reset_flag = reset_flag
+        self.by_epoch = by_epoch

     @abstractmethod
     def log(self, runner):
@@ -35,7 +41,9 @@ class LoggerHook(Hook):
             runner.log_buffer.clear()  # clear logs of last epoch

     def after_train_iter(self, runner):
-        if self.every_n_inner_iters(runner, self.interval):
+        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+            runner.log_buffer.average(self.interval)
+        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
             runner.log_buffer.average(self.interval)
         elif self.end_of_epoch(runner) and not self.ignore_last:
             # not precise but more stable
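Since every logger subclass must now forward the flag, a third-party hook only needs to extend its constructor. A minimal sketch, assuming LoggerHook is importable from mmcv.runner as in mmcv 1.x; PrintLoggerHook is a hypothetical example, not part of the library:

from mmcv.runner import LoggerHook


class PrintLoggerHook(LoggerHook):
    """Hypothetical logger that just prints the averaged buffer."""

    def __init__(self, interval=10, ignore_last=True, reset_flag=False,
                 by_epoch=True):
        # forward by_epoch so the base class picks the right interval check
        super(PrintLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)

    def log(self, runner):
        print(dict(runner.log_buffer.output))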
@@ -40,8 +40,10 @@ class PaviLoggerHook(LoggerHook):
                  add_last_ckpt=False,
                  interval=10,
                  ignore_last=True,
-                 reset_flag=True):
-        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+                 reset_flag=True,
+                 by_epoch=True):
+        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
         self.init_kwargs = init_kwargs
         self.add_graph = add_graph
         self.add_last_ckpt = add_last_ckpt
@@ -14,9 +14,10 @@ class TensorboardLoggerHook(LoggerHook):
                  log_dir=None,
                  interval=10,
                  ignore_last=True,
-                 reset_flag=True):
+                 reset_flag=True,
+                 by_epoch=True):
         super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
-                                                    reset_flag)
+                                                    reset_flag, by_epoch)
         self.log_dir = log_dir

     @master_only
@@ -35,7 +35,8 @@ class TextLoggerHook(LoggerHook):
                  ignore_last=True,
                  reset_flag=False,
                  interval_exp_name=1000):
-        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
         self.by_epoch = by_epoch
         self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
@@ -61,7 +62,7 @@ class TextLoggerHook(LoggerHook):
         # print exp name for users to distinguish experiments
         # at every ``interval_exp_name`` iterations and the end of each epoch
         if runner.meta is not None and 'exp_name' in runner.meta:
-            if (self.every_n_inner_iters(runner, self.interval_exp_name)) or (
+            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                     self.by_epoch and self.end_of_epoch(runner)):
                 exp_info = f'Exp name: {runner.meta["exp_name"]}'
                 runner.logger.info(exp_info)
@@ -144,7 +145,10 @@ class TextLoggerHook(LoggerHook):
         mode = 'train' if 'time' in runner.log_buffer.output else 'val'
         log_dict['mode'] = mode
         log_dict['epoch'] = runner.epoch + 1
-        log_dict['iter'] = runner.inner_iter + 1
+        if self.by_epoch:
+            log_dict['iter'] = runner.inner_iter + 1
+        else:
+            log_dict['iter'] = runner.iter + 1
         # only record lr of the first param group
         cur_lr = runner.current_lr()
         if isinstance(cur_lr, list):
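The effect of the log_dict['iter'] branch is easiest to see with concrete numbers. A small stand-alone sketch, where SimpleNamespace stands in for the runner and the values are made up:

from types import SimpleNamespace

# hypothetical runner state: 8th iteration of the 4th epoch,
# 1508th iteration of the whole run
runner = SimpleNamespace(epoch=3, inner_iter=7, iter=1507)

def iter_for_log(runner, by_epoch):
    # by_epoch runs report the position within the current epoch,
    # iteration-based runs report the global position instead
    return runner.inner_iter + 1 if by_epoch else runner.iter + 1

print(iter_for_log(runner, by_epoch=True))   # 8
print(iter_for_log(runner, by_epoch=False))  # 1508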
@@ -65,11 +65,10 @@ class IterBasedRunner(BaseRunner):
     def val(self, data_loader, **kwargs):
         self.model.eval()
         self.mode = 'val'
-        self._inner_iter = 0
         self.data_loader = data_loader
         self.call_hook('before_val_iter')
         data_batch = next(data_loader)
-        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+        outputs = self.model.val_step(data_batch, **kwargs)
         if not isinstance(outputs, dict):
             raise TypeError('model.val_step() must return a dict')
         if 'log_vars' in outputs:
@@ -107,6 +106,7 @@ class IterBasedRunner(BaseRunner):

         while self.iter < max_iters:
             for i, flow in enumerate(workflow):
+                self._inner_iter = 0
                 mode, iters = flow
                 if not isinstance(mode, str) or not hasattr(self, mode):
                     raise ValueError(
@@ -220,4 +220,6 @@ class IterBasedRunner(BaseRunner):
         self.register_optimizer_hook(optimizer_config)
         self.register_checkpoint_hook(checkpoint_config)
         self.register_hook(IterTimerHook())
-        self.register_logger_hooks(log_config)
+        if log_config is not None:
+            log_config.setdefault('by_epoch', False)
+            self.register_logger_hooks(log_config)
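setdefault only fills in by_epoch when the caller has not set it, so an explicit choice in the config still wins. A quick illustration with plain dicts (the configs below are made-up examples):

# iteration-based runner: by_epoch gets the sensible default False
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
log_config.setdefault('by_epoch', False)
assert log_config['by_epoch'] is False

# an explicit setting is left untouched
explicit = dict(interval=50, by_epoch=True, hooks=[dict(type='TextLoggerHook')])
explicit.setdefault('by_epoch', False)
assert explicit['by_epoch'] is True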