fast_scnn test: fix BN bug.
parent
8df067e15d
commit
54ae2eeb5f
|
@ -675,13 +675,14 @@ def test_fastscnn_backbone():
|
|||
model = FastSCNN()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 1024, 2048)
|
||||
num_batch_picts = 4
|
||||
imgs = torch.randn(num_batch_picts, 3, 1024, 2048)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([1, 64, 128, 256]) # higher-res
|
||||
assert feat[1].shape == torch.Size([1, 128, 32, 64]) # lower-res
|
||||
assert feat[2].shape == torch.Size([1, 128, 128, 256]) # FFM output
|
||||
assert feat[0].shape == torch.Size([num_batch_picts, 64, 128, 256]) # higher-res
|
||||
assert feat[1].shape == torch.Size([num_batch_picts, 128, 32, 64]) # lower-res
|
||||
assert feat[2].shape == torch.Size([num_batch_picts, 128, 128, 256]) # FFM output
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue