[Fix] Raise AssertError when eval_res is a null dict

This commit is contained in:
zhouzaida 2021-10-13 11:42:56 +08:00 committed by Wenwei Zhang
parent 085e63629b
commit f2d11076eb
2 changed files with 19 additions and 2 deletions

View File

@ -358,6 +358,9 @@ class EvalHook(Hook):
"""
eval_res = self.dataloader.dataset.evaluate(
results, logger=runner.logger, **self.eval_kwargs)
assert eval_res, '`eval_res` should not be a null dict.'
for name, val in eval_res.items():
runner.log_buffer.output[name] = val
runner.log_buffer.ready = True

View File

@ -126,10 +126,24 @@ def test_eval_hook():
with pytest.raises(KeyError):
# rule must be in keys of rule_map
test_dataset = Model()
test_dataset = ExampleDataset()
data_loader = DataLoader(test_dataset)
EvalHook(data_loader, save_best='auto', rule='unsupport')
with pytest.raises(AssertionError):
# eval_res returned by `dataset.evaluate()` should not be a null dict
class _EvalDataset(ExampleDataset):
def evaluate(self, results, logger=None):
return {}
test_dataset = _EvalDataset()
data_loader = DataLoader(test_dataset)
eval_hook = EvalHook(data_loader)
runner = _build_epoch_runner()
runner.register_hook(eval_hook)
runner.run([data_loader], [('train', 1)], 1)
test_dataset = ExampleDataset()
loader = DataLoader(test_dataset)
model = Model()
@ -450,7 +464,7 @@ def test_logger(runner, by_epoch, eval_hook_priority):
path = osp.join(tmpdir, next(scandir(tmpdir, '.json')))
with open(path) as fr:
fr.readline() # skip first line which is hook_msg
fr.readline() # skip the first line which is `hook_msg`
train_log = json.loads(fr.readline())
assert train_log['mode'] == 'train' and 'time' in train_log
val_log = json.loads(fr.readline())