# Source path: mmengine/tests/test_runner/test_loop.py
# (code-viewer header: 67 lines, 1.9 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from mmengine.runner.loop import (EpochBasedTrainLoop, IterBasedTrainLoop,
TestLoop, ValLoop)
class ToyDataset(Dataset):
    """Minimal in-memory dataset used to drive the loop tests.

    Every sample is an all-zero array of shape ``(1, 1, 1)``; there are 30
    of them, so one full pass over the dataset visits 30 items.
    """

    META = dict()  # type: ignore
    # Class-level array shared by all instances: 30 samples of shape (1, 1, 1).
    data = np.zeros((30, 1, 1, 1))

    def __len__(self):
        """Return the number of samples (length of the first axis)."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` as a tensor.

        ``torch.from_numpy`` does not copy, so the returned tensor shares
        memory with the class-level ``data`` array.
        """
        return torch.from_numpy(self.data[index])
class TestLoops(TestCase):
    """Run each runner loop against a mocked runner and evaluator."""

    def setUp(self) -> None:
        """Build a fresh mocked runner and evaluator for every test."""
        self.runner = Mock()
        self.runner.call_hooks = Mock()
        self.runner.model = Mock()
        # Counters start at zero so the train-loop tests can check the
        # values they hold after ``run()``.
        self.runner.epoch = 0
        self.runner.iter = 0
        self.runner.inner_iter = 0
        self.runner.model.train_step = Mock()
        self.runner.model.val_step = Mock()

        self.evaluator = Mock()
        self.evaluator.process = Mock()
        self.evaluator.evaluate = Mock()

    def test_epoch_based_train_loop(self):
        """Three epochs over ToyDataset should leave epoch=3, iter=90."""
        train_loop = EpochBasedTrainLoop(
            runner=self.runner, loader=DataLoader(ToyDataset()), max_epoch=3)
        train_loop.run()
        self.assertEqual(train_loop.runner.epoch, 3)
        # 3 epochs x 30 samples, with DataLoader's default batch size of 1.
        self.assertEqual(train_loop.runner.iter, 90)

    def test_iter_based_train_loop(self):
        """An iteration-based run of 25 steps should not advance the epoch."""
        train_loop = IterBasedTrainLoop(
            runner=self.runner, loader=DataLoader(ToyDataset()), max_iter=25)
        train_loop.run()
        self.assertEqual(train_loop.runner.epoch, 0)
        self.assertEqual(train_loop.runner.iter, 25)

    def test_val_loop(self):
        """Smoke test: ``ValLoop.run()`` completes without raising."""
        val_loop = ValLoop(
            runner=self.runner,
            loader=DataLoader(ToyDataset()),
            evaluator=self.evaluator)
        val_loop.run()

    def test_test_loop(self):
        """Smoke test: ``TestLoop.run()`` completes without raising."""
        test_loop = TestLoop(
            runner=self.runner,
            loader=DataLoader(ToyDataset()),
            evaluator=self.evaluator)
        test_loop.run()