mmpretrain/tests/test_models/test_backbones/test_resnext.py

62 lines
2.0 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
2020-06-14 12:08:37 +08:00
import pytest
import torch
from mmcls.models.backbones import ResNeXt
from mmcls.models.backbones.resnext import Bottleneck as BottleneckX
2020-06-25 11:57:50 +08:00
def test_bottleneck():
    """Check the ResNeXt ``Bottleneck`` block: validation, layout and forward.

    Three properties are exercised:

    * an unsupported ``style`` value is rejected with ``AssertionError``;
    * a strided, pytorch-style block places the stride and the 32-way
      grouping on ``conv2``, whose output width is 128 channels;
    * a block built with ``base_channels=16`` maps a ``(1, 64, 56, 56)``
      input to an output of the same shape.
    """
    # Style must be in ['pytorch', 'caffe']; any other value must fail.
    with pytest.raises(AssertionError):
        BottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')

    # Structure: in pytorch style the grouped conv2 carries the stride.
    strided_block = BottleneckX(
        64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
    grouped_conv = strided_block.conv2
    assert grouped_conv.stride == (2, 2)
    assert grouped_conv.groups == 32
    assert grouped_conv.out_channels == 128

    # Forward: the block must preserve the input's shape.
    shape_block = BottleneckX(
        64, 64, base_channels=16, groups=32, width_per_group=4)
    inputs = torch.randn(1, 64, 56, 56)
    outputs = shape_block(inputs)
    assert outputs.shape == torch.Size([1, 64, 56, 56])
2020-06-25 11:57:50 +08:00
def test_resnext():
    """Check ResNeXt depth validation and per-stage output shapes.

    An unsupported depth (18) must raise ``KeyError``. Then, for a
    depth-50 model with ``groups=32`` and ``width_per_group=4``, the test
    runs each configuration in ``cases`` in order. For each one it:

    * builds the model with the given ``out_indices``;
    * checks every ``BottleneckX`` inside uses 32 groups on ``conv2``;
    * initialises weights and runs a ``(1, 3, 224, 224)`` input in train mode;
    * checks the number and shapes of the returned feature maps.
    """
    # ResNeXt depth should be in [50, 101, 152]; other depths raise KeyError.
    with pytest.raises(KeyError):
        ResNeXt(depth=18)

    # (out_indices, expected feature-map shapes for a 224x224 input)
    cases = [
        # group 32, width_per_group 4, all four stages returned
        ((0, 1, 2, 3), [
            [1, 256, 56, 56],
            [1, 512, 28, 28],
            [1, 1024, 14, 14],
            [1, 2048, 7, 7],
        ]),
        # group 32, width_per_group 4, only the last stage returned
        ((3, ), [
            [1, 2048, 7, 7],
        ]),
    ]

    for out_indices, expected_shapes in cases:
        model = ResNeXt(
            depth=50, groups=32, width_per_group=4, out_indices=out_indices)

        bottlenecks = (
            module for module in model.modules()
            if isinstance(module, BottleneckX))
        for bottleneck in bottlenecks:
            assert bottleneck.conv2.groups == 32

        model.init_weights()
        model.train()
        images = torch.randn(1, 3, 224, 224)
        outputs = model(images)

        assert len(outputs) == len(expected_shapes)
        # Index explicitly (as opposed to iterating) so each output is
        # fetched through __getitem__, exactly as a direct feat[i] would be.
        for stage_idx, shape in enumerate(expected_shapes):
            assert outputs[stage_idx].shape == torch.Size(shape)