mirror of
https://github.com/open-mmlab/mmfewshot.git
synced 2025-06-03 14:49:43 +08:00
Merge 624bce637e429cf73b72440ab07631cdf67a1b23 into af4fad50b0e00946467244dc8bce4d232d2b79c2
This commit is contained in:
commit
d19d83a799
@ -86,6 +86,10 @@ class MetaTestEvalHook(Hook):
|
||||
warnings.warn('runner.meta is None. Creating an empty one.')
|
||||
runner.meta = dict()
|
||||
runner.meta.setdefault('hook_msgs', dict())
|
||||
if runner.meta['hook_msgs'].get('best_score', False):
|
||||
self.best_score = runner.meta['hook_msgs']['best_score']
|
||||
runner.logger.info(
|
||||
f'Previous best score is: {self.best_score}.')
|
||||
self.best_ckpt_path = runner.meta['hook_msgs'].get(
|
||||
'best_ckpt', None)
|
||||
|
||||
|
@ -171,3 +171,48 @@ def test_epoch_eval_hook():
|
||||
max_epochs=1)
|
||||
runner.register_hook(eval_hook)
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
|
||||
|
||||
def test_resume_eval_hook():
|
||||
test_set_loader = DataLoader(
|
||||
toy_meta_test_dataset(),
|
||||
batch_size=1,
|
||||
sampler=None,
|
||||
num_workers=0,
|
||||
shuffle=False)
|
||||
query_loader = DataLoader(
|
||||
toy_meta_test_dataset().query(),
|
||||
batch_size=1,
|
||||
sampler=None,
|
||||
num_workers=0,
|
||||
shuffle=False)
|
||||
support_loader = DataLoader(
|
||||
toy_meta_test_dataset().support(),
|
||||
batch_size=1,
|
||||
sampler=None,
|
||||
num_workers=0,
|
||||
shuffle=False)
|
||||
model = ExampleModel()
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
test_dataset = ExampleDataset()
|
||||
loader = DataLoader(test_dataset, batch_size=1)
|
||||
# test EvalHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
eval_hook = MetaTestEvalHook(
|
||||
support_loader,
|
||||
query_loader,
|
||||
test_set_loader,
|
||||
num_test_tasks=10,
|
||||
meta_test_cfg=dict(support={}, query={}))
|
||||
runner = mmcv.runner.EpochBasedRunner(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger(),
|
||||
max_epochs=1)
|
||||
runner.register_hook(eval_hook)
|
||||
runner.meta = {'hook_msgs': {'best_score': 99.0}}
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
assert eval_hook.best_score == 99.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user