mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Fix] Raise AssertError when eval_res is a null dict
This commit is contained in:
parent
085e63629b
commit
f2d11076eb
@ -358,6 +358,9 @@ class EvalHook(Hook):
|
|||||||
"""
|
"""
|
||||||
eval_res = self.dataloader.dataset.evaluate(
|
eval_res = self.dataloader.dataset.evaluate(
|
||||||
results, logger=runner.logger, **self.eval_kwargs)
|
results, logger=runner.logger, **self.eval_kwargs)
|
||||||
|
|
||||||
|
assert eval_res, '`eval_res` should not be a null dict.'
|
||||||
|
|
||||||
for name, val in eval_res.items():
|
for name, val in eval_res.items():
|
||||||
runner.log_buffer.output[name] = val
|
runner.log_buffer.output[name] = val
|
||||||
runner.log_buffer.ready = True
|
runner.log_buffer.ready = True
|
||||||
|
@ -126,10 +126,24 @@ def test_eval_hook():
|
|||||||
|
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
# rule must be in keys of rule_map
|
# rule must be in keys of rule_map
|
||||||
test_dataset = Model()
|
test_dataset = ExampleDataset()
|
||||||
data_loader = DataLoader(test_dataset)
|
data_loader = DataLoader(test_dataset)
|
||||||
EvalHook(data_loader, save_best='auto', rule='unsupport')
|
EvalHook(data_loader, save_best='auto', rule='unsupport')
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# eval_res returned by `dataset.evaluate()` should not be a null dict
|
||||||
|
class _EvalDataset(ExampleDataset):
|
||||||
|
|
||||||
|
def evaluate(self, results, logger=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
test_dataset = _EvalDataset()
|
||||||
|
data_loader = DataLoader(test_dataset)
|
||||||
|
eval_hook = EvalHook(data_loader)
|
||||||
|
runner = _build_epoch_runner()
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([data_loader], [('train', 1)], 1)
|
||||||
|
|
||||||
test_dataset = ExampleDataset()
|
test_dataset = ExampleDataset()
|
||||||
loader = DataLoader(test_dataset)
|
loader = DataLoader(test_dataset)
|
||||||
model = Model()
|
model = Model()
|
||||||
@ -450,7 +464,7 @@ def test_logger(runner, by_epoch, eval_hook_priority):
|
|||||||
|
|
||||||
path = osp.join(tmpdir, next(scandir(tmpdir, '.json')))
|
path = osp.join(tmpdir, next(scandir(tmpdir, '.json')))
|
||||||
with open(path) as fr:
|
with open(path) as fr:
|
||||||
fr.readline() # skip first line which is hook_msg
|
fr.readline() # skip the first line which is `hook_msg`
|
||||||
train_log = json.loads(fr.readline())
|
train_log = json.loads(fr.readline())
|
||||||
assert train_log['mode'] == 'train' and 'time' in train_log
|
assert train_log['mode'] == 'train' and 'time' in train_log
|
||||||
val_log = json.loads(fr.readline())
|
val_log = json.loads(fr.readline())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user