mmclassification/tests/test_backbone.py

23 lines
771 B
Python
Raw Normal View History

2020-06-03 15:51:17 +08:00
import torch
import torch.nn as nn
from mmcls.models.backbones import MobileNetv2
def test_mobilenetv2_backbone():
    """Check the feature maps returned by MobileNetv2.

    Builds the backbone with ``widen_factor=1.0`` and ``nn.ReLU6`` as the
    activation, calls ``init_weights()`` and ``train()``, feeds it one random
    tensor of size (1, 3, 224, 224), and verifies the number of returned
    outputs together with the shape of the outputs at indices 0 through 6.
    """
    # Test MobileNetv2 with widen_factor 1.0, activation nn.ReLU6
    backbone = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
    backbone.init_weights()
    backbone.train()

    inputs = torch.randn(1, 3, 224, 224)
    outputs = backbone(inputs)
    assert len(outputs) == 8

    # Expected shape of each output, listed by output index.
    # NOTE(review): eight outputs are expected above, but only indices 0-6
    # are shape-checked in this span -- confirm whether index 7 is covered.
    expected_shapes = (
        (1, 16, 112, 112),
        (1, 24, 56, 56),
        (1, 32, 28, 28),
        (1, 64, 14, 14),
        (1, 96, 14, 14),
        (1, 160, 7, 7),
        (1, 320, 7, 7),
    )
    for index, shape in enumerate(expected_shapes):
        assert outputs[index].shape == torch.Size(shape)