mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Fix unit tests (#449)
This commit is contained in:
parent
429bb27972
commit
e907931fb8
@ -9,7 +9,6 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from mmengine.evaluator import Evaluator
|
from mmengine.evaluator import Evaluator
|
||||||
from mmengine.registry import LOOPS
|
from mmengine.registry import LOOPS
|
||||||
from mmengine.utils import is_list_of
|
|
||||||
from .amp import autocast
|
from .amp import autocast
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .utils import calc_dynamic_intervals
|
from .utils import calc_dynamic_intervals
|
||||||
@ -389,7 +388,7 @@ class TestLoop(BaseLoop):
|
|||||||
fp16: bool = False):
|
fp16: bool = False):
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
|
|
||||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
if isinstance(evaluator, dict) or isinstance(evaluator, list):
|
||||||
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
|
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
|
||||||
else:
|
else:
|
||||||
self.evaluator = evaluator # type: ignore
|
self.evaluator = evaluator # type: ignore
|
||||||
|
@ -8,6 +8,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from mmengine.evaluator import Evaluator
|
||||||
from mmengine.hooks import EMAHook
|
from mmengine.hooks import EMAHook
|
||||||
from mmengine.model import BaseModel, ExponentialMovingAverage
|
from mmengine.model import BaseModel, ExponentialMovingAverage
|
||||||
from mmengine.optim import OptimWrapper
|
from mmengine.optim import OptimWrapper
|
||||||
@ -81,7 +82,7 @@ class TestEMAHook(TestCase):
|
|||||||
def test_ema_hook(self):
|
def test_ema_hook(self):
|
||||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
model = ToyModel1().to(device)
|
model = ToyModel1().to(device)
|
||||||
evaluator = Mock()
|
evaluator = Evaluator([])
|
||||||
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
|
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -558,10 +558,10 @@ class TestRunner(TestCase):
|
|||||||
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
|
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
|
||||||
val_cfg=dict(),
|
val_cfg=dict(),
|
||||||
val_dataloader=val_dataloader,
|
val_dataloader=val_dataloader,
|
||||||
val_evaluator=ToyMetric1(),
|
val_evaluator=[ToyMetric1()],
|
||||||
test_cfg=dict(),
|
test_cfg=dict(),
|
||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
test_evaluator=ToyMetric1(),
|
test_evaluator=[ToyMetric1()],
|
||||||
default_hooks=dict(param_scheduler=toy_hook),
|
default_hooks=dict(param_scheduler=toy_hook),
|
||||||
custom_hooks=[toy_hook2],
|
custom_hooks=[toy_hook2],
|
||||||
experiment_name='test_init14')
|
experiment_name='test_init14')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user