mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
fix mmdp unittest (#60)
This commit is contained in:
parent
bc759e5550
commit
5170676a2f
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -57,7 +58,7 @@ def test_is_model_wrapper():
|
||||
assert is_model_wrapper(model_wrapper)
|
||||
|
||||
|
||||
class TestMMDataParallel:
|
||||
class TestMMDataParallel(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Setup the demo image in every test method.
|
||||
@ -70,7 +71,7 @@ class TestMMDataParallel:
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(2, 2, 1)
|
||||
self.conv = nn.Conv2d(1, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
@ -101,7 +102,7 @@ class TestMMDataParallel:
|
||||
with pytest.raises(AssertionError):
|
||||
mmdp.train_step(torch.zeros([1, 1, 3, 3]))
|
||||
|
||||
out = self.model.train_step([torch.zeros([1, 1, 3, 3])])
|
||||
out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
|
||||
assert out.shape == (1, 2, 3, 3)
|
||||
|
||||
def test_val_step(self):
|
||||
@ -122,5 +123,5 @@ class TestMMDataParallel:
|
||||
with pytest.raises(AssertionError):
|
||||
mmdp.val_step(torch.zeros([1, 1, 3, 3]))
|
||||
|
||||
out = self.model.val_step([torch.zeros([1, 1, 3, 3])])
|
||||
out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
|
||||
assert out.shape == (1, 2, 3, 3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user