fix linting

This commit is contained in:
lixiaojie 2020-06-14 01:25:08 +08:00
parent fb3934fd2c
commit 85844a3a9e

View File

@ -221,14 +221,14 @@ def test_shufflenetv1_backbone():
if is_norm(m): if is_norm(m):
assert isinstance(m, _BatchNorm) assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224) imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs) feat = model(imgs)
assert len(feat) == 2 assert len(feat) == 2
assert feat[0].shape == torch.Size((1, 480, 14, 14)) assert feat[0].shape == torch.Size((1, 480, 14, 14))
assert feat[1].shape == torch.Size((1, 960, 7, 7)) assert feat[1].shape == torch.Size((1, 960, 7, 7))
# Test ShuffleNetv1 forward with layers 2 forward # Test ShuffleNetv1 forward with layers 2 forward
model = ShuffleNetv1(groups=3, out_indices=(2,)) model = ShuffleNetv1(groups=3, out_indices=(2, ))
model.init_weights() model.init_weights()
model.train() model.train()