Fix unit tests (#449)

This commit is contained in:
Mashiro 2022-08-21 14:54:24 +08:00 committed by GitHub
parent 429bb27972
commit e907931fb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 5 deletions

View File

@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .amp import autocast from .amp import autocast
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals from .utils import calc_dynamic_intervals
@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
fp16: bool = False): fp16: bool = False):
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict): if isinstance(evaluator, dict) or isinstance(evaluator, list):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else: else:
self.evaluator = evaluator # type: ignore self.evaluator = evaluator # type: ignore

View File

@ -8,6 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmengine.evaluator import Evaluator
from mmengine.hooks import EMAHook from mmengine.hooks import EMAHook
from mmengine.model import BaseModel, ExponentialMovingAverage from mmengine.model import BaseModel, ExponentialMovingAverage
from mmengine.optim import OptimWrapper from mmengine.optim import OptimWrapper
@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
def test_ema_hook(self): def test_ema_hook(self):
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = ToyModel1().to(device) model = ToyModel1().to(device)
evaluator = Mock() evaluator = Evaluator([])
evaluator.evaluate = Mock(return_value=dict(acc=0.5)) evaluator.evaluate = Mock(return_value=dict(acc=0.5))
runner = Runner( runner = Runner(
model=model, model=model,

View File

@ -558,10 +558,10 @@ class TestRunner(TestCase):
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]), param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
val_cfg=dict(), val_cfg=dict(),
val_dataloader=val_dataloader, val_dataloader=val_dataloader,
val_evaluator=ToyMetric1(), val_evaluator=[ToyMetric1()],
test_cfg=dict(), test_cfg=dict(),
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
test_evaluator=ToyMetric1(), test_evaluator=[ToyMetric1()],
default_hooks=dict(param_scheduler=toy_hook), default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2], custom_hooks=[toy_hook2],
experiment_name='test_init14') experiment_name='test_init14')