fix unit test
parent
8060e0f620
commit
66e9bb017d
|
@ -220,7 +220,6 @@ class ShuffleNetv1(BaseBackbone):
|
||||||
stride=2,
|
stride=2,
|
||||||
padding=1,
|
padding=1,
|
||||||
bias=False)
|
bias=False)
|
||||||
|
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
|
|
|
@ -36,13 +36,9 @@ def test_shufflenetv1_shuffleuint():
|
||||||
# combine must be in ['add', 'concat']
|
# combine must be in ['add', 'concat']
|
||||||
ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')
|
ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
# inplanes must be divisible by groups
|
|
||||||
ShuffleUnit(64, 64, groups=3, first_block=True, combine='add')
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# inplanes must be equal tp = outplanes when combine='add'
|
# inplanes must be equal tp = outplanes when combine='add'
|
||||||
ShuffleUnit(64, 24, groups=3, first_block=True, combine='add')
|
ShuffleUnit(64, 24, groups=4, first_block=True, combine='add')
|
||||||
|
|
||||||
# Test ShuffleUnit with combine='add'
|
# Test ShuffleUnit with combine='add'
|
||||||
block = ShuffleUnit(24, 24, groups=3, first_block=True, combine='add')
|
block = ShuffleUnit(24, 24, groups=3, first_block=True, combine='add')
|
||||||
|
|
Loading…
Reference in New Issue