Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
* add mytrain.py for test * test before layers * test attr in layers * test classifier * delete mytrain.py * add patchembed and hybridembed * add patchembed and hybridembed to __init__ * test patchembed and hybridembed * fix some comments
190 lines · 6.3 KiB · Python
import pytest
|
|
import torch
|
|
from torch.nn.modules import GroupNorm
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
from mmcls.models.backbones import VGG
|
|
from mmcls.models.utils import (Augments, HybridEmbed, InvertedResidual,
|
|
PatchEmbed, SELayer, channel_shuffle,
|
|
make_divisible)
|
|
|
|
|
|
def is_norm(modules):
    """Return whether ``modules`` is a normalization layer.

    Args:
        modules (nn.Module): The module to inspect.

    Returns:
        bool: True if ``modules`` is a ``GroupNorm`` or an instance of any
        ``_BatchNorm`` subclass (e.g. ``BatchNorm1d/2d/3d``), else False.
    """
    return isinstance(modules, (GroupNorm, _BatchNorm))
|
|
|
|
|
|
def test_make_divisible():
    """Check rounding of ``make_divisible`` to a multiple of the divisor."""
    # test min_value is None: 34 is rounded to the nearest multiple of 8
    result = make_divisible(34, 8, None)
    assert result == 32

    # test when new_value < min_ratio * value: 10 rounds to 8, which is
    # smaller than 0.9 * 10 = 9, so one more divisor is added -> 16
    result = make_divisible(10, 8, min_ratio=0.9)
    assert result == 16

    # test min_ratio = 0.8: 33 rounds to 32, which is not smaller than
    # 0.8 * 33 = 26.4, so no extra divisor is added
    result = make_divisible(33, 8, min_ratio=0.8)
    assert result == 32
|
|
|
|
|
|
def test_channel_shuffle():
    """Check that ``channel_shuffle`` permutes channels as expected."""
    x = torch.randn(1, 24, 56, 56)
    with pytest.raises(AssertionError):
        # num_channels should be divisible by groups
        channel_shuffle(x, 7)

    groups = 3
    num_channels = x.size(1)
    channels_per_group = num_channels // groups
    out = channel_shuffle(x, groups)

    # Shuffling only reorders channels, so the shape must not change.
    assert out.shape == x.shape

    # test the output value when groups = 3: input channel c must appear at
    # output channel (c % channels_per_group) * groups
    # + c // channels_per_group. Whole (batch, H, W) slices are compared at
    # once instead of element by element, which keeps the test fast.
    for c in range(num_channels):
        c_out = c % channels_per_group * groups + c // channels_per_group
        assert torch.equal(x[:, c], out[:, c_out])
|
|
|
|
|
|
def test_inverted_residual():
    """Check argument validation and forward shapes of InvertedResidual."""

    with pytest.raises(AssertionError):
        # stride must be in [1, 2]
        InvertedResidual(16, 16, 32, stride=3)

    with pytest.raises(AssertionError):
        # se_cfg must be None or dict
        InvertedResidual(16, 16, 32, se_cfg=list())

    # Add expand conv only if in_channels and mid_channels are not the same
    assert InvertedResidual(32, 16, 32).with_expand_conv is False
    assert InvertedResidual(16, 16, 32).with_expand_conv is True

    # Test InvertedResidual forward, stride=1
    block = InvertedResidual(16, 16, 32, stride=1)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert getattr(block, 'se', None) is None
    assert block.with_res_shortcut
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward, stride=2
    block = InvertedResidual(16, 16, 32, stride=2)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert not block.with_res_shortcut
    assert x_out.shape == torch.Size((1, 16, 28, 28))

    # Test InvertedResidual forward with se layer
    se_cfg = dict(channels=32)
    block = InvertedResidual(16, 16, 32, stride=1, se_cfg=se_cfg)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert isinstance(block.se, SELayer)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward without expand conv
    block = InvertedResidual(32, 16, 32)
    x = torch.randn(1, 32, 56, 56)
    x_out = block(x)
    assert getattr(block, 'expand_conv', None) is None
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward with GroupNorm
    block = InvertedResidual(
        16, 16, 32, norm_cfg=dict(type='GN', num_groups=2))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    for m in block.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward with HSigmoid
    block = InvertedResidual(16, 16, 32, act_cfg=dict(type='HSigmoid'))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual forward with checkpoint.
    # NOTE(review): the checkpointed forward is presumably only taken when
    # the input requires grad (confirm against InvertedResidual.forward),
    # so the input is created with requires_grad=True to exercise it.
    block = InvertedResidual(16, 16, 32, with_cp=True)
    x = torch.randn(1, 16, 56, 56, requires_grad=True)
    x_out = block(x)
    assert block.with_cp
    assert x_out.shape == torch.Size((1, 16, 56, 56))
|
|
|
|
|
|
def test_augments():
    """Run Augments with several cutmix/mixup configs and check shapes."""
    imgs = torch.randn(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4, ))

    expected_img_shape = torch.Size((4, 3, 32, 32))
    expected_label_shape = torch.Size((4, 10))

    cases = [
        # Test cutmix
        dict(type='BatchCutMix', alpha=1., num_classes=10, prob=1.),
        # Test mixup
        dict(type='BatchMixup', alpha=1., num_classes=10, prob=1.),
        # Test cutmixup (probabilities sum to 0.8)
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
        ],
        # Test cutmixup with probabilities summing to 1.0
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.5)
        ],
        # Test cutmixup with an explicit Identity entry (sum is 1.0)
        [
            dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
            dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3),
            dict(type='Identity', num_classes=10, prob=0.2)
        ],
    ]

    for augments_cfg in cases:
        batch_imgs, batch_labels = Augments(augments_cfg)(imgs, labels)
        assert batch_imgs.shape == expected_img_shape
        assert batch_labels.shape == expected_label_shape
|
|
|
|
|
|
def test_embed():
    """Check output token shapes of PatchEmbed and HybridEmbed."""

    # Test PatchEmbed with its default settings: 196 = 14 * 14 tokens
    # (presumably 16x16 patches by default), each of width 768
    embed = PatchEmbed()
    tokens = embed(torch.randn(1, 3, 224, 224))
    assert tokens.shape == torch.Size((1, 196, 768))

    # Test PatchEmbed with stride = 8: (224 - 16) // 8 + 1 = 27 positions
    # per side, so 27 * 27 = 729 tokens
    embed = PatchEmbed(conv_cfg=dict(kernel_size=16, stride=8))
    tokens = embed(torch.randn(1, 3, 224, 224))
    assert tokens.shape == torch.Size((1, 729, 768))

    # Test VGG11 HybridEmbed: the CNN feature map is flattened into 49 tokens
    vgg = VGG(11, norm_eval=True)
    vgg.init_weights()
    embed = HybridEmbed(vgg)
    tokens = embed(torch.randn(1, 3, 224, 224))
    assert tokens.shape == torch.Size((1, 49, 768))
|