mirror of https://github.com/open-mmlab/mmcv.git

print a warning when `eval_res` is an empty dict

parent f2d11076eb
commit 0633f91139
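This change removes the hard `assert eval_res, ...` in `EvalHook`: when `dataset.evaluate()` returns an empty dict while `save_best` is set, the hook now emits a `UserWarning` and skips saving the best checkpoint instead of crashing the run (see https://github.com/open-mmlab/mmdetection/issues/6265). The `save_best` branches in `EvalHook` and `DistEvalHook` are also guarded against a `None` key score.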
@@ -271,7 +271,9 @@ class EvalHook(Hook):
         results = self.test_fn(runner.model, self.dataloader)
         runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
-        if self.save_best:
+        # key_score may be `None`, so saving the best checkpoint
+        # needs to be skipped in that case
+        if self.save_best and key_score:
             self._save_ckpt(runner, key_score)
 
     def _should_evaluate(self, runner):
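For illustration, a minimal standalone sketch of this guard (the scaffolding below is hypothetical, not mmcv's API):

from types import SimpleNamespace

def maybe_save_best(hook, key_score):
    # evaluate() may return None (e.g. when evaluation produced no
    # usable metric), so the truthiness check skips saving then.
    if hook.save_best and key_score:
        hook.saved.append(key_score)

hook = SimpleNamespace(save_best='auto', saved=[])
maybe_save_best(hook, None)  # skipped: no key score
maybe_save_best(hook, 0.83)  # saved
assert hook.saved == [0.83]

Note that the truthiness check also skips a legitimate `key_score` of exactly 0, not just `None`.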
@@ -359,13 +361,21 @@ class EvalHook(Hook):
         eval_res = self.dataloader.dataset.evaluate(
             results, logger=runner.logger, **self.eval_kwargs)
 
-        assert eval_res, '`eval_res` should not be a null dict.'
-
         for name, val in eval_res.items():
             runner.log_buffer.output[name] = val
         runner.log_buffer.ready = True
 
         if self.save_best is not None:
+            # If the performance of the model is poor, `eval_res` may be an
+            # empty dict, which would raise an exception when `self.save_best`
+            # is not None. More details at
+            # https://github.com/open-mmlab/mmdetection/issues/6265.
+            if not eval_res:
+                warnings.warn(
+                    'Since `eval_res` is an empty dict, the behavior to save '
+                    'the best checkpoint will be skipped in this evaluation.')
+                return None
+
             if self.key_indicator == 'auto':
                 # infer from eval_results
                 self._init_rule(self.rule, list(eval_res.keys())[0])
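A minimal standalone sketch of the new control flow (the `pick_key_score` helper is hypothetical, not mmcv's API):

import warnings

def pick_key_score(eval_res, save_best=None, key_indicator='auto'):
    # Mirrors the diff: with save_best set and an empty eval_res,
    # warn and return None instead of raising an AssertionError.
    if save_best is not None:
        if not eval_res:
            warnings.warn(
                'Since `eval_res` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return None
        if key_indicator == 'auto':
            key_indicator = list(eval_res.keys())[0]  # infer from eval_res
        return eval_res[key_indicator]
    return None

assert pick_key_score({}, save_best='auto') is None          # warns
assert pick_key_score({'mAP': 0.7}, save_best='auto') == 0.7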
@@ -493,6 +503,7 @@ class DistEvalHook(EvalHook):
             print('\n')
             runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
             key_score = self.evaluate(runner, results)
-
-            if self.save_best:
+            # key_score may be `None`, so saving the best checkpoint
+            # needs to be skipped in that case
+            if self.save_best and key_score:
                 self._save_ckpt(runner, key_score)
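This is the same `save_best and key_score` guard as in `EvalHook` above, applied on `DistEvalHook`'s evaluation path.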
@@ -130,8 +130,9 @@ def test_eval_hook():
         data_loader = DataLoader(test_dataset)
         EvalHook(data_loader, save_best='auto', rule='unsupport')
 
-    with pytest.raises(AssertionError):
-        # eval_res returned by `dataset.evaluate()` should not be a null dict
+    # if eval_res is an empty dict, print a warning
+    with pytest.warns(UserWarning) as record_warnings:
+
         class _EvalDataset(ExampleDataset):
 
             def evaluate(self, results, logger=None):
@@ -139,10 +140,20 @@ def test_eval_hook():
 
         test_dataset = _EvalDataset()
         data_loader = DataLoader(test_dataset)
-        eval_hook = EvalHook(data_loader)
+        eval_hook = EvalHook(data_loader, save_best='auto')
         runner = _build_epoch_runner()
         runner.register_hook(eval_hook)
         runner.run([data_loader], [('train', 1)], 1)
+    # Since many warnings may be thrown, we only need to check that the
+    # expected warning is among them
+    expected_message = ('Since `eval_res` is an empty dict, the behavior to '
+                        'save the best checkpoint will be skipped in this '
+                        'evaluation.')
+    for warning in record_warnings:
+        if str(warning.message) == expected_message:
+            break
+    else:
+        assert False
 
     test_dataset = ExampleDataset()
     loader = DataLoader(test_dataset)
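The final assertion relies on Python's for/else: the else branch runs only if the loop never hits break, i.e. the expected warning was never recorded. A standalone sketch of the same pattern (the `emit` helper is hypothetical):

import warnings

import pytest

def emit():
    warnings.warn('something unrelated')
    warnings.warn('the expected message')

def test_expected_warning_recorded():
    with pytest.warns(UserWarning) as record_warnings:
        emit()
    # Several warnings may be recorded, so scan for the expected one.
    for warning in record_warnings:
        if str(warning.message) == 'the expected message':
            break
    else:
        assert False, 'expected warning was not raised'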