mmengine/tests/test_runner/test_loop.py
RangiLyu 64b1d183b9
Add runner unit tests. (#68)
* add runner unit tests

* update

* update

* add test custom loop and hook

* add test model wrapper

* add test setup env

* fix typo

* fix launcher

* fix typo

* test default scope

* add logger test

* fix dataloader

* add test loop

* resolve comments

* resolve comments
2022-03-03 19:44:36 +08:00

67 lines
1.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from mmengine.runner.loop import (EpochBasedTrainLoop, IterBasedTrainLoop,
TestLoop, ValLoop)
class ToyDataset(Dataset):
    """Tiny in-memory dataset used to drive the loops under test.

    It exposes 30 all-zero samples; indexing with an integer yields a
    float64 tensor of shape ``(1, 1, 1)``.
    """

    META = {}  # type: ignore
    # Shared class-level buffer: 30 samples, each of shape (1, 1, 1).
    data = np.zeros((30, 1, 1, 1))

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index`` wrapped as a torch tensor."""
        sample = self.data[index]
        return torch.from_numpy(sample)
class TestLoops(TestCase):
    """Exercise the train/val/test loops against a fully mocked runner."""

    def setUp(self) -> None:
        """Build a mocked runner and evaluator shared by every test."""
        self.runner = Mock()
        self.runner.call_hooks = Mock()
        self.runner.model = Mock()
        # The counters must be real integers (not Mock attributes) so that
        # the assertions below can compare them against expected values.
        self.runner.epoch = 0
        self.runner.iter = 0
        self.runner.inner_iter = 0
        self.runner.model.train_step = Mock()
        self.runner.model.val_step = Mock()

        self.evaluator = Mock()
        self.evaluator.process = Mock()
        self.evaluator.evaluate = Mock()

    def test_epoch_based_train_loop(self):
        """Running 3 epochs should leave the runner at epoch 3, iter 90."""
        train_loop = EpochBasedTrainLoop(
            runner=self.runner, loader=DataLoader(ToyDataset()), max_epoch=3)
        train_loop.run()
        # ToyDataset holds 30 samples and the DataLoader uses its default
        # arguments (batch_size=1), so 3 epochs are expected to give
        # 3 * 30 = 90 iterations.
        # unittest assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report both values on failure.
        self.assertEqual(train_loop.runner.epoch, 3)
        self.assertEqual(train_loop.runner.iter, 90)

    def test_iter_based_train_loop(self):
        """Running 25 iterations should not advance the epoch counter."""
        train_loop = IterBasedTrainLoop(
            runner=self.runner, loader=DataLoader(ToyDataset()), max_iter=25)
        train_loop.run()
        self.assertEqual(train_loop.runner.epoch, 0)
        self.assertEqual(train_loop.runner.iter, 25)

    def test_val_loop(self):
        """Smoke test: ``ValLoop.run`` completes without raising."""
        val_loop = ValLoop(
            runner=self.runner,
            loader=DataLoader(ToyDataset()),
            evaluator=self.evaluator)
        val_loop.run()

    def test_test_loop(self):
        """Smoke test: ``TestLoop.run`` completes without raising."""
        test_loop = TestLoop(
            runner=self.runner,
            loader=DataLoader(ToyDataset()),
            evaluator=self.evaluator)
        test_loop.run()