mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Fix unit tests (#449)
This commit is contained in:
parent
429bb27972
commit
e907931fb8
@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.registry import LOOPS
|
||||
from mmengine.utils import is_list_of
|
||||
from .amp import autocast
|
||||
from .base_loop import BaseLoop
|
||||
from .utils import calc_dynamic_intervals
|
||||
@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
|
||||
fp16: bool = False):
|
||||
super().__init__(runner, dataloader)
|
||||
|
||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||
if isinstance(evaluator, dict) or isinstance(evaluator, list):
|
||||
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
|
||||
else:
|
||||
self.evaluator = evaluator # type: ignore
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.model import BaseModel, ExponentialMovingAverage
|
||||
from mmengine.optim import OptimWrapper
|
||||
@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
|
||||
def test_ema_hook(self):
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
model = ToyModel1().to(device)
|
||||
evaluator = Mock()
|
||||
evaluator = Evaluator([])
|
||||
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
|
||||
runner = Runner(
|
||||
model=model,
|
||||
|
@ -558,10 +558,10 @@ class TestRunner(TestCase):
|
||||
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
|
||||
val_cfg=dict(),
|
||||
val_dataloader=val_dataloader,
|
||||
val_evaluator=ToyMetric1(),
|
||||
val_evaluator=[ToyMetric1()],
|
||||
test_cfg=dict(),
|
||||
test_dataloader=test_dataloader,
|
||||
test_evaluator=ToyMetric1(),
|
||||
test_evaluator=[ToyMetric1()],
|
||||
default_hooks=dict(param_scheduler=toy_hook),
|
||||
custom_hooks=[toy_hook2],
|
||||
experiment_name='test_init14')
|
||||
|
Loading…
x
Reference in New Issue
Block a user