Fix unit tests (#449)

This commit is contained in:
Mashiro 2022-08-21 14:54:24 +08:00 committed by GitHub
parent 429bb27972
commit e907931fb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 5 deletions

View File

@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
fp16: bool = False):
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
if isinstance(evaluator, dict) or isinstance(evaluator, list):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator # type: ignore

View File

@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from torch.utils.data import Dataset
from mmengine.evaluator import Evaluator
from mmengine.hooks import EMAHook
from mmengine.model import BaseModel, ExponentialMovingAverage
from mmengine.optim import OptimWrapper
@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
def test_ema_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = ToyModel1().to(device)
evaluator = Mock()
evaluator = Evaluator([])
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
runner = Runner(
model=model,

View File

@ -558,10 +558,10 @@ class TestRunner(TestCase):
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
val_cfg=dict(),
val_dataloader=val_dataloader,
val_evaluator=ToyMetric1(),
val_evaluator=[ToyMetric1()],
test_cfg=dict(),
test_dataloader=test_dataloader,
test_evaluator=ToyMetric1(),
test_evaluator=[ToyMetric1()],
default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2],
experiment_name='test_init14')