23 lines
771 B
Python
23 lines
771 B
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from mmcls.models.backbones import MobileNetv2
|
||
|
|
||
|
|
||
|
def test_mobilenetv2_backbone():
    """Check the output shapes of ``MobileNetv2`` on one 224x224 random input.

    The backbone is built with ``widen_factor=1.0`` and ``nn.ReLU6`` as the
    activation, its weights are initialized, and it is put in training mode
    before a single forward pass.
    """
    # Expected shape of each returned tensor, in the order the backbone
    # returns them.
    expected_shapes = [
        (1, 16, 112, 112),
        (1, 24, 56, 56),
        (1, 32, 28, 28),
        (1, 64, 14, 14),
        (1, 96, 14, 14),
        (1, 160, 7, 7),
        (1, 320, 7, 7),
    ]
    # NOTE(review): the output count is asserted to be 8, but only the first
    # 7 shapes are checked; the 8th output's shape is unchecked -- TODO
    # confirm its expected shape and add it to the table.

    backbone = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
    backbone.init_weights()
    backbone.train()

    batch = torch.randn(1, 3, 224, 224)
    outputs = backbone(batch)

    assert len(outputs) == 8
    for idx, shape in enumerate(expected_shapes):
        assert outputs[idx].shape == torch.Size(shape)