import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.utils import (Augments, InvertedResidual, SELayer,
                                channel_shuffle, make_divisible)


def is_norm(module):
    """Check whether the module is a normalization layer."""
    return isinstance(module, (GroupNorm, _BatchNorm))


def test_make_divisible():
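    # A note on the helper under test, as I read the mmcls implementation:
    # make_divisible rounds `value` to the nearest multiple of `divisor`
    # (never below `min_value`, which defaults to `divisor` when None); if
    # the rounded result falls below `min_ratio * value`, one more
    # `divisor` is added.
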
    # test min_value is None
    result = make_divisible(34, 8, None)
    assert result == 32

    # test when the rounded value < min_ratio * value,
    # so one more divisor is added
    result = make_divisible(10, 8, min_ratio=0.9)
    assert result == 16

    # test min_ratio = 0.8
    result = make_divisible(33, 8, min_ratio=0.8)
    assert result == 32


def test_channel_shuffle():
    x = torch.randn(1, 24, 56, 56)
    with pytest.raises(AssertionError):
        # num_channels should be divisible by groups
        channel_shuffle(x, 7)

    groups = 3
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    out = channel_shuffle(x, groups)
    # test the output value when groups = 3
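    # The expected permutation below follows the standard channel shuffle:
    # view (N, C, H, W) as (N, groups, C // groups, H, W), transpose the
    # two group dimensions and flatten back, so input channel c lands at
    # index (c % channels_per_group) * groups + c // channels_per_group.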
    for b in range(batch_size):
        for c in range(num_channels):
            c_out = c % channels_per_group * groups + c // channels_per_group
            for i in range(height):
                for j in range(width):
                    assert x[b, c, i, j] == out[b, c_out, i, j]


def test_inverted_residual():
    with pytest.raises(AssertionError):
        # stride must be in [1, 2]
        InvertedResidual(16, 16, 32, stride=3)

    with pytest.raises(AssertionError):
        # se_cfg must be None or dict
        InvertedResidual(16, 16, 32, se_cfg=list())

    # The expand conv is only added when in_channels and mid_channels differ
    assert InvertedResidual(32, 16, 32).with_expand_conv is False
    assert InvertedResidual(16, 16, 32).with_expand_conv is True

    # Test InvertedResidual forward, stride=1
    block = InvertedResidual(16, 16, 32, stride=1)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert getattr(block, 'se', None) is None
    assert block.with_res_shortcut
    assert x_out.shape == torch.Size((1, 16, 56, 56))
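    # (with_res_shortcut should hold exactly when stride == 1 and
    # in_channels == out_channels, so the input can be added back to the
    # output; the stride=2 case below confirms the negative side.)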

    # Test InvertedResidual forward, stride=2
    block = InvertedResidual(16, 16, 32, stride=2)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert not block.with_res_shortcut
    assert x_out.shape == torch.Size((1, 16, 28, 28))

    # Test InvertedResidual forward with se layer
    se_cfg = dict(channels=32)
    block = InvertedResidual(16, 16, 32, stride=1, se_cfg=se_cfg)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert isinstance(block.se, SELayer)
    assert x_out.shape == torch.Size((1, 16, 56, 56))
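    # (The SE layer re-weights the mid_channels feature map, which is why
    # se_cfg['channels'] matches mid_channels=32 here.)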

    # Test InvertedResidual forward without expand conv
    block = InvertedResidual(32, 16, 32)
    x = torch.randn(1, 32, 56, 56)
    x_out = block(x)
    assert getattr(block, 'expand_conv', None) is None
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward with GroupNorm
    block = InvertedResidual(
        16, 16, 32, norm_cfg=dict(type='GN', num_groups=2))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    for m in block.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward with HSigmoid
    block = InvertedResidual(16, 16, 32, act_cfg=dict(type='HSigmoid'))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward with checkpoint
    block = InvertedResidual(16, 16, 32, with_cp=True)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert block.with_cp
    assert x_out.shape == torch.Size((1, 16, 56, 56))
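    # (with_cp=True wraps the main branch in torch.utils.checkpoint,
    # trading recomputation during backward for lower training memory;
    # the forward output is the same as without checkpointing.)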


def test_augments():
    imgs = torch.randn(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4, ))

    # Test cutmix
    augments_cfg = dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.)
    augs = Augments(augments_cfg)
    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 10))
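    # (Both mix augmentations return soft labels: the integer labels are
    # one-hot encoded and blended, hence the (4, 10) label shape.)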

    # Test mixup
    augments_cfg = dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.)
    augs = Augments(augments_cfg)
    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 10))

    # Test cutmixup (cutmix and mixup combined)
    augments_cfg = [
        dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
        dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
    ]
    augs = Augments(augments_cfg)
    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 10))
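    # (With a list config, Augments picks one augmentation per call
    # according to `prob`; the probs above sum to 0.8, and my understanding
    # is that the remaining 0.2 falls through to an implicit Identity,
    # i.e. no mixing. The two configs below cover probs that sum to 1,
    # without and with an explicit Identity entry.)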

    augments_cfg = [
        dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
        dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5)
    ]
    augs = Augments(augments_cfg)
    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 10))

    augments_cfg = [
        dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
        dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
        dict(type='Identity', num_classes=10, prob=0.2)
    ]
    augs = Augments(augments_cfg)
    mixed_imgs, mixed_labels = augs(imgs, labels)
    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
    assert mixed_labels.shape == torch.Size((4, 10))