add unit test for inverted_residual: debug 2

pull/58/head
johnzja 2020-08-11 20:27:35 +08:00
parent b3bd282cc7
commit db364ef26a
2 changed files with 5 additions and 5 deletions

View File

@ -676,13 +676,13 @@ def test_fastscnn_backbone():
model.init_weights() model.init_weights()
model.train() model.train()
batch_size = 4 batch_size = 4
imgs = torch.randn(batch_size, 3, 1024, 2048) imgs = torch.randn(batch_size, 3, 512, 1024)
feat = model(imgs) feat = model(imgs)
assert len(feat) == 3 assert len(feat) == 3
# higher-res # higher-res
assert feat[0].shape == torch.Size([batch_size, 64, 128, 256]) assert feat[0].shape == torch.Size([batch_size, 64, 64, 128])
# lower-res # lower-res
assert feat[1].shape == torch.Size([batch_size, 128, 32, 64]) assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
# FFM output # FFM output
assert feat[2].shape == torch.Size([batch_size, 128, 128, 256]) assert feat[2].shape == torch.Size([batch_size, 128, 64, 128])

View File

@ -10,7 +10,7 @@ def test_inv_residual():
# test stride assertion. # test stride assertion.
InvertedResidual(32, 32, 3, 4) InvertedResidual(32, 32, 3, 4)
# test default config without res connection. # test default config with res connection.
# set expand_ratio = 4, stride = 1 and inp=oup. # set expand_ratio = 4, stride = 1 and inp=oup.
inv_module = InvertedResidual(32, 32, 1, 4) inv_module = InvertedResidual(32, 32, 1, 4)
assert inv_module.use_res_connect assert inv_module.use_res_connect