[Fix] Fix wrong iter number and progress number in the logging during val/test time (#914)

* Remove iter number in logging during val/test time

* Fix typo

* Modified EvalHook for eval mode to print the correct iter number
Yezhen Cong 2021-05-13 20:32:33 +08:00 committed by GitHub
parent a1d3bf1c80
commit b36c4de157
3 changed files with 17 additions and 2 deletions


@@ -73,7 +73,10 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
         if rank == 0:
             batch_size = len(result)
-            for _ in range(batch_size * world_size):
+            batch_size_all = batch_size * world_size
+            if batch_size_all + prog_bar.completed > len(dataset):
+                batch_size_all = len(dataset) - prog_bar.completed
+            for _ in range(batch_size_all):
                 prog_bar.update()
 
     # collect results from all ranks
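
The clamp above matters because a DistributedSampler-style loader pads the last batch so every rank processes the same number of samples; without it the progress bar can run past len(dataset). A minimal standalone sketch (not mmcv code, all numbers hypothetical) of the arithmetic:

dataset_len = 11   # assumed dataset size
world_size = 4     # assumed number of ranks
batch_size = 1     # per-rank batch size on the final step
completed = 8      # prog_bar.completed before the final step

# Old behaviour: always advance by batch_size * world_size,
# counting the padded duplicates as real samples.
old_total = completed + batch_size * world_size      # 12 > 11

# Fixed behaviour: clamp the final step to the samples actually left.
batch_size_all = batch_size * world_size
if batch_size_all + completed > dataset_len:
    batch_size_all = dataset_len - completed         # 3
new_total = completed + batch_size_all               # 11 == dataset_len

print(old_total, new_total)  # 12 11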


@@ -180,6 +180,7 @@ class EvalHook(Hook):
         from mmcv.engine import single_gpu_test
         results = single_gpu_test(runner.model, self.dataloader)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
         if self.save_best:
             self._save_ckpt(runner, key_score)
@@ -371,6 +372,7 @@ class DistEvalHook(EvalHook):
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
             key_score = self.evaluate(runner, results)
             if self.save_best:
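
Both hooks hand the val dataloader length to the logger through the runner's log buffer rather than touching runner.iter. A hedged sketch of that hand-off, with a plain dict standing in for runner.log_buffer.output and 1000 as an assumed dataloader length:

log_buffer_output = {}

# What EvalHook/DistEvalHook do after evaluation:
log_buffer_output['eval_iter_num'] = 1000  # stands in for len(self.dataloader)

# What TextLoggerHook.log does (see the last diff): pop() both reads the
# value and removes it, so the next train-time log falls back to the
# runner's real iteration counter.
cur_iter = log_buffer_output.pop('eval_iter_num')
assert cur_iter == 1000 and 'eval_iter_num' not in log_buffer_output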


@@ -99,6 +99,10 @@ class TextLoggerHook(LoggerHook):
             if torch.cuda.is_available():
                 log_str += f'memory: {log_dict["memory"]}, '
         else:
+            # val/test time
+            # here 1000 is the length of the val dataloader
+            # by epoch: Epoch(val) [4][1000]
+            # by iter: Iter(val) [1000]
             if self.by_epoch:
                 log_str = f'Epoch({log_dict["mode"]}) ' \
                           f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
@@ -141,10 +145,16 @@ class TextLoggerHook(LoggerHook):
         return items
 
     def log(self, runner):
+        if 'eval_iter_num' in runner.log_buffer.output:
+            # this doesn't modify runner.iter and is independent of by_epoch
+            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+        else:
+            cur_iter = self.get_iter(runner, inner_iter=True)
+
         log_dict = OrderedDict(
             mode=self.get_mode(runner),
             epoch=self.get_epoch(runner),
-            iter=self.get_iter(runner, inner_iter=True))
+            iter=cur_iter)
 
         # only record lr of the first param group
         cur_lr = runner.current_lr()
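
With eval_iter_num in place, the val-time header reports the dataloader length instead of a stale training iter. A minimal sketch of the resulting format, assuming by_epoch training and the example numbers from the comment above:

# log_dict mirrors the OrderedDict built in TextLoggerHook.log, with
# iter taken from eval_iter_num (assumed val dataloader length: 1000).
log_dict = dict(mode='val', epoch=4, iter=1000)

by_epoch = True  # assumed runner configuration
if by_epoch:
    log_str = f'Epoch({log_dict["mode"]}) [{log_dict["epoch"]}][{log_dict["iter"]}]\t'
else:
    log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

print(log_str)  # Epoch(val) [4][1000]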