mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
fix linting
This commit is contained in:
parent
fb3934fd2c
commit
85844a3a9e
@ -221,14 +221,14 @@ def test_shufflenetv1_backbone():
|
|||||||
if is_norm(m):
|
if is_norm(m):
|
||||||
assert isinstance(m, _BatchNorm)
|
assert isinstance(m, _BatchNorm)
|
||||||
|
|
||||||
imgs = torch.randn(1, 3, 224, 224)
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
feat = model(imgs)
|
feat = model(imgs)
|
||||||
assert len(feat) == 2
|
assert len(feat) == 2
|
||||||
assert feat[0].shape == torch.Size((1, 480, 14, 14))
|
assert feat[0].shape == torch.Size((1, 480, 14, 14))
|
||||||
assert feat[1].shape == torch.Size((1, 960, 7, 7))
|
assert feat[1].shape == torch.Size((1, 960, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with layers 2 forward
|
# Test ShuffleNetv1 forward with layers 2 forward
|
||||||
model = ShuffleNetv1(groups=3, out_indices=(2,))
|
model = ShuffleNetv1(groups=3, out_indices=(2, ))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user