print a warning information when eval_res is an empty dict

pull/1489/head
zhouzaida 2021-11-06 00:24:20 +08:00 committed by Wenwei Zhang
parent f2d11076eb
commit 0633f91139
2 changed files with 30 additions and 8 deletions

View File

@ -271,7 +271,9 @@ class EvalHook(Hook):
results = self.test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
# the key_score may be `None` so it needs to skip the action to save
# the best checkpoint
if self.save_best and key_score:
self._save_ckpt(runner, key_score)
def _should_evaluate(self, runner):
@ -359,13 +361,21 @@ class EvalHook(Hook):
eval_res = self.dataloader.dataset.evaluate(
results, logger=runner.logger, **self.eval_kwargs)
assert eval_res, '`eval_res` should not be a null dict.'
for name, val in eval_res.items():
runner.log_buffer.output[name] = val
runner.log_buffer.ready = True
if self.save_best is not None:
# If the performance of model is pool, the `eval_res` may be an
# empty dict and it will raise exception when `self.save_best` is
# not None. More details at
# https://github.com/open-mmlab/mmdetection/issues/6265.
if not eval_res:
warnings.warn(
'Since `eval_res` is an empty dict, the behavior to save '
'the best checkpoint will be skipped in this evaluation.')
return None
if self.key_indicator == 'auto':
# infer from eval_results
self._init_rule(self.rule, list(eval_res.keys())[0])
@ -493,6 +503,7 @@ class DistEvalHook(EvalHook):
print('\n')
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
# the key_score may be `None` so it needs to skip the action to
# save the best checkpoint
if self.save_best and key_score:
self._save_ckpt(runner, key_score)

View File

@ -130,8 +130,9 @@ def test_eval_hook():
data_loader = DataLoader(test_dataset)
EvalHook(data_loader, save_best='auto', rule='unsupport')
with pytest.raises(AssertionError):
# eval_res returned by `dataset.evaluate()` should not be a null dict
# if eval_res is an empty dict, print a warning information
with pytest.warns(UserWarning) as record_warnings:
class _EvalDataset(ExampleDataset):
def evaluate(self, results, logger=None):
@ -139,10 +140,20 @@ def test_eval_hook():
test_dataset = _EvalDataset()
data_loader = DataLoader(test_dataset)
eval_hook = EvalHook(data_loader)
eval_hook = EvalHook(data_loader, save_best='auto')
runner = _build_epoch_runner()
runner.register_hook(eval_hook)
runner.run([data_loader], [('train', 1)], 1)
# Since there will be many warnings thrown, we just need to check if the
# expected exceptions are thrown
expected_message = ('Since `eval_res` is an empty dict, the behavior to '
'save the best checkpoint will be skipped in this '
'evaluation.')
for warning in record_warnings:
if str(warning.message) == expected_message:
break
else:
assert False
test_dataset = ExampleDataset()
loader = DataLoader(test_dataset)