mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* add runner unit tests * update * update * add test custom loop and hook * add test model wrapper * add test setup env * fix typo * fix launcher * fix typo * test default scope * add logger test * fix dataloader * add test loop * resolve comments * resolve comments
67 lines
1.9 KiB
Python
67 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
from unittest.mock import Mock
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
from mmengine.runner.loop import (EpochBasedTrainLoop, IterBasedTrainLoop,
|
|
TestLoop, ValLoop)
|
|
|
|
|
|
class ToyDataset(Dataset):
    """Tiny in-memory dataset of 30 all-zero samples.

    The backing array has shape ``(30, 1, 1, 1)``, so the dataset has
    30 items and each item is a ``(1, 1, 1)`` tensor of zeros.
    """

    # Empty metadata mapping; not read anywhere in this class
    # (presumably consumed by code outside this file -- confirm).
    META = dict()  # type: ignore
    # Shared, class-level sample storage: 30 samples of shape (1, 1, 1).
    data = np.zeros((30, 1, 1, 1))

    def __len__(self):
        """Return the number of samples, i.e. the first dim of ``data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return sample ``index`` as a tensor built from the NumPy data."""
        return torch.from_numpy(self.data[index])
|
|
|
|
|
|
class TestLoops(TestCase):
    """Smoke tests for the train/val/test loops driven by a mocked runner.

    Every test feeds a loop a ``DataLoader`` over ``ToyDataset`` (30
    samples, default batch size of 1, hence 30 batches per pass) together
    with the ``Mock`` runner and evaluator created in :meth:`setUp`.
    """

    def setUp(self) -> None:
        """Create a fresh mocked runner and evaluator for each test."""
        self.runner = Mock()
        self.runner.call_hooks = Mock()
        self.runner.model = Mock()
        # Start the progress counters at real integers so the values the
        # loops leave behind can be compared against expected numbers.
        self.runner.epoch = 0
        self.runner.iter = 0
        self.runner.inner_iter = 0
        self.runner.model.train_step = Mock()
        self.runner.model.val_step = Mock()

        self.evaluator = Mock()
        self.evaluator.process = Mock()
        self.evaluator.evaluate = Mock()

    def test_epoch_based_train_loop(self):
        """Three epochs over 30 batches must end at epoch 3, iter 90."""
        train_loop = EpochBasedTrainLoop(
            runner=self.runner, loader=DataLoader(ToyDataset()), max_epoch=3)
        train_loop.run()
        self.assertEqual(train_loop.runner.epoch, 3)
        self.assertEqual(train_loop.runner.iter, 90)

    def test_iter_based_train_loop(self):
        """A 25-iteration run must end at iter 25 with epoch still 0."""
        train_loop = IterBasedTrainLoop(
            runner=self.runner, loader=DataLoader(ToyDataset()), max_iter=25)
        train_loop.run()
        self.assertEqual(train_loop.runner.epoch, 0)
        self.assertEqual(train_loop.runner.iter, 25)

    def test_val_loop(self):
        """The validation loop must run to completion without raising."""
        val_loop = ValLoop(
            runner=self.runner,
            loader=DataLoader(ToyDataset()),
            evaluator=self.evaluator)
        val_loop.run()

    def test_test_loop(self):
        """The test loop must run to completion without raising."""
        test_loop = TestLoop(
            runner=self.runner,
            loader=DataLoader(ToyDataset()),
            evaluator=self.evaluator)
        test_loop.run()
|