From 0667136f054548b771df74c3a8a06b1d31511905 Mon Sep 17 00:00:00 2001
From: lixiaojie
Date: Thu, 4 Jun 2020 02:22:53 +0800
Subject: [PATCH 1/4] update test

---
 mmcls/models/backbones/mobilenet_v2.py |   7 +-
 tests/test_backbone.py                 | 206 ++++++++++++++++++++++++-
 2 files changed, 209 insertions(+), 4 deletions(-)

diff --git a/mmcls/models/backbones/mobilenet_v2.py b/mmcls/models/backbones/mobilenet_v2.py
index 57731cf83..44abf2618 100644
--- a/mmcls/models/backbones/mobilenet_v2.py
+++ b/mmcls/models/backbones/mobilenet_v2.py
@@ -2,8 +2,8 @@ import logging
 
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-
-from ..runner import load_checkpoint
+from mmcv.runner import load_checkpoint
+t
 from .base_backbone import BaseBackbone
 from .weight_init import constant_init, kaiming_init
 
@@ -154,7 +154,7 @@ class MobileNetv2(BaseBackbone):
     def __init__(self,
                  widen_factor=1.,
                  activation=nn.ReLU6,
-                 out_indices=(0, 1, 2, 3, 4, 5, 6),
+                 out_indices=(0, 1, 2, 3, 4, 5, 6, 7),
                  frozen_stages=-1,
                  bn_eval=True,
                  bn_frozen=False,
@@ -177,6 +177,7 @@ class MobileNetv2(BaseBackbone):
 
         self.activation = activation(inplace=True)
         self.out_indices = out_indices
+        assert frozen_stages <= 7
        self.frozen_stages = frozen_stages
         self.bn_eval = bn_eval
         self.bn_frozen = bn_frozen
diff --git a/tests/test_backbone.py b/tests/test_backbone.py
index 0e7819228..838f1b016 100644
--- a/tests/test_backbone.py
+++ b/tests/test_backbone.py
@@ -1,11 +1,128 @@
+import pytest
 import torch
 import torch.nn as nn
+from torch.nn.modules import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmcls.models.backbones import MobileNetv2
+from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
+
+
+def is_block(modules):
+    """Check if is a MobileNetv2 building block."""
+    if isinstance(modules, (InvertedResidual,)):
+        return True
+    return False
+
+
+def is_norm(modules):
+    """Check if is one of the norms."""
+    if isinstance(modules, (GroupNorm, _BatchNorm)):
+        return True
+    return False
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_mobilenetv2_invertedresidual():
+
+    with pytest.raises(AssertionError):
+        # stride must be in [1, 2]
+        InvertedResidual(64, 16,
+                         stride=3, expand_ratio=6)
+
+    # Test InvertedResidual forward, stride=1
+    block = InvertedResidual(64, 16,
+                             stride=1,
+                             expand_ratio=6)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 56, 56])
+
+    # Test InvertedResidual forward, stride=2
+    block = InvertedResidual(64, 16,
+                             stride=2,
+                             expand_ratio=6)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 28, 28])
+
+    # Test InvertedResidual with checkpoint forward
+    block = InvertedResidual(64, 16,
+                             stride=1,
+                             expand_ratio=6,
+                             with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 56, 56])
+
+    # Test InvertedResidual with activation=nn.ReLU
+    block = InvertedResidual(64, 16,
+                             stride=1,
+                             expand_ratio=6,
+                             activation=nn.ReLU)
+    x = torch.randn(1, 64, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 56, 56])
 
 
 def test_mobilenetv2_backbone():
-    # Test MobileNetv2 with widen_factor 1.0, activation nn.ReLU6
+    with pytest.raises(TypeError):
+        # pretrained must be a string path
+        model = MobileNetv2()
+        model.init_weights(pretrained=0)
+
+    with pytest.raises(AssertionError):
+        # frozen_stages must be no greater than 7
+        MobileNetv2(frozen_stages=8)
+
+    # Test MobileNetv2
+    model = MobileNetv2()
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test MobileNetv2 with first stage frozen
+    frozen_stages = 1
+    model = MobileNetv2(frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.bn1.training is False
+    for layer in [model.conv1, model.bn1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # Test MobileNetv2 with first stage frozen
+    model = MobileNetv2(bn_frozen=True)
+    model.init_weights()
+    model.train()
+    assert model.bn1.training is False
+
+    for i in range(1, 8):
+        layer = getattr(model, f'layer{i}')
+
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+                for param in mod.parameters():
+                    assert param.requires_grad is False
+
+    # Test MobileNetv2 forward with widen_factor=1.0
     model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
     model.init_weights()
     model.train()
@@ -20,3 +137,90 @@ def test_mobilenetv2_backbone():
     assert feat[4].shape == torch.Size([1, 96, 14, 14])
     assert feat[5].shape == torch.Size([1, 160, 7, 7])
     assert feat[6].shape == torch.Size([1, 320, 7, 7])
+
+    # Test MobileNetv2 forward with activation=nn.ReLU
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 8
+    assert feat[0].shape == torch.Size([1, 16, 112, 112])
+    assert feat[1].shape == torch.Size([1, 24, 56, 56])
+    assert feat[2].shape == torch.Size([1, 32, 28, 28])
+    assert feat[3].shape == torch.Size([1, 64, 14, 14])
+    assert feat[4].shape == torch.Size([1, 96, 14, 14])
+    assert feat[5].shape == torch.Size([1, 160, 7, 7])
+    assert feat[6].shape == torch.Size([1, 320, 7, 7])
+
+    # Test MobileNetv2 with BatchNorm forward
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 8
+    assert feat[0].shape == torch.Size([1, 16, 112, 112])
+    assert feat[1].shape == torch.Size([1, 24, 56, 56])
+    assert feat[2].shape == torch.Size([1, 32, 28, 28])
+    assert feat[3].shape == torch.Size([1, 64, 14, 14])
+    assert feat[4].shape == torch.Size([1, 96, 14, 14])
+    assert feat[5].shape == torch.Size([1, 160, 7, 7])
+    assert feat[6].shape == torch.Size([1, 320, 7, 7])
+
+    # Test MobileNetv2 with BatchNorm forward
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 8
+    assert feat[0].shape == torch.Size([1, 16, 112, 112])
+    assert feat[1].shape == torch.Size([1, 24, 56, 56])
+    assert feat[2].shape == torch.Size([1, 32, 28, 28])
+    assert feat[3].shape == torch.Size([1, 64, 14, 14])
+    assert feat[4].shape == torch.Size([1, 96, 14, 14])
+    assert feat[5].shape == torch.Size([1, 160, 7, 7])
+    assert feat[6].shape == torch.Size([1, 320, 7, 7])
+
+    # Test MobileNetv2 with layers 1, 3, 5 out forward
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6,
+                        out_indices=(0, 2, 4))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 16, 112, 112])
+    assert feat[1].shape == torch.Size([1, 32, 28, 28])
+    assert feat[2].shape == torch.Size([1, 96, 14, 14])
+
+    # Test MobileNetv2 with checkpoint forward
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6,
+                        with_cp=True)
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 8
+    assert feat[0].shape == torch.Size([1, 16, 112, 112])
+    assert feat[1].shape == torch.Size([1, 24, 56, 56])
+    assert feat[2].shape == torch.Size([1, 32, 28, 28])
+    assert feat[3].shape == torch.Size([1, 64, 14, 14])
+    assert feat[4].shape == torch.Size([1, 96, 14, 14])
+    assert feat[5].shape == torch.Size([1, 160, 7, 7])
+    assert feat[6].shape == torch.Size([1, 320, 7, 7])
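
A note on the with_cp flag exercised above: when it is set, the block runs its forward under torch.utils.checkpoint, trading extra forward compute for activation memory. A minimal sketch of that pattern, assuming the mmcv-style _inner_forward layout (the class below is illustrative, not part of the patch):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as cp


    class CheckpointedBlock(nn.Module):
        """Toy residual block mirroring the with_cp pattern."""

        def __init__(self, channels, with_cp=False):
            super(CheckpointedBlock, self).__init__()
            self.with_cp = with_cp
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x):

            def _inner_forward(x):
                return x + self.conv(x)

            if self.with_cp and x.requires_grad:
                # Recompute _inner_forward's activations during backward
                # instead of storing them, saving memory.
                out = cp.checkpoint(_inner_forward, x)
            else:
                out = _inner_forward(x)
            return out


    x = torch.randn(1, 16, 56, 56, requires_grad=True)
    block = CheckpointedBlock(16, with_cp=True)
    assert block(x).shape == torch.Size([1, 16, 56, 56])
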
From b392ba44ccdb502f93a15ff0e83ce7ec48881145 Mon Sep 17 00:00:00 2001
From: lixiaojie
Date: Thu, 4 Jun 2020 02:25:53 +0800
Subject: [PATCH 2/4] fix

---
 mmcls/models/backbones/mobilenet_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcls/models/backbones/mobilenet_v2.py b/mmcls/models/backbones/mobilenet_v2.py
index 44abf2618..c2b895ab1 100644
--- a/mmcls/models/backbones/mobilenet_v2.py
+++ b/mmcls/models/backbones/mobilenet_v2.py
@@ -3,7 +3,7 @@ import logging
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.runner import load_checkpoint
-t
+
 from .base_backbone import BaseBackbone
 from .weight_init import constant_init, kaiming_init
 
From fa499114402078b4fba8c7ccc2f1e6a8f4f9940f Mon Sep 17 00:00:00 2001
From: lixiaojie
Date: Thu, 4 Jun 2020 02:42:53 +0800
Subject: [PATCH 3/4] reformat

---
 mmcls/models/backbones/mobilenet_v2.py | 89 +++++++++++++------------
 tests/test_backbone.py                 | 31 +++------
 2 files changed, 56 insertions(+), 64 deletions(-)

diff --git a/mmcls/models/backbones/mobilenet_v2.py b/mmcls/models/backbones/mobilenet_v2.py
index c2b895ab1..b5262a76f 100644
--- a/mmcls/models/backbones/mobilenet_v2.py
+++ b/mmcls/models/backbones/mobilenet_v2.py
@@ -22,13 +22,12 @@ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
 
 def conv_1x1_bn(inp, oup, activation=nn.ReLU6):
     return nn.Sequential(
-        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
-        nn.BatchNorm2d(oup),
-        activation(inplace=True)
-    )
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup),
+        activation(inplace=True))
 
 
 class ConvBNReLU(nn.Sequential):
+
     def __init__(self,
                  in_planes,
                  out_planes,
@@ -39,16 +38,15 @@ class ConvBNReLU(nn.Sequential):
         padding = (kernel_size - 1) // 2
 
         super(ConvBNReLU, self).__init__(
-            nn.Conv2d(in_planes,
-                      out_planes,
-                      kernel_size,
-                      stride,
-                      padding,
-                      groups=groups,
-                      bias=False),
-            nn.BatchNorm2d(out_planes),
-            activation(inplace=True)
-        )
+            nn.Conv2d(
+                in_planes,
+                out_planes,
+                kernel_size,
+                stride,
+                padding,
+                groups=groups,
+                bias=False), nn.BatchNorm2d(out_planes),
+            activation(inplace=True))
 
 
 def _make_divisible(v, divisor, min_value=None):
@@ -62,6 +60,7 @@ def _make_divisible(v, divisor, min_value=None):
 
 
 class InvertedResidual(nn.Module):
+
     def __init__(self,
                  inplanes,
                  outplanes,
@@ -79,17 +78,18 @@ class InvertedResidual(nn.Module):
         layers = []
         if expand_ratio != 1:
             # pw
-            layers.append(ConvBNReLU(inplanes,
-                                     hidden_dim,
-                                     kernel_size=1,
-                                     activation=activation))
+            layers.append(
+                ConvBNReLU(
+                    inplanes, hidden_dim, kernel_size=1,
+                    activation=activation))
         layers.extend([
             # dw
-            ConvBNReLU(hidden_dim,
-                       hidden_dim,
-                       stride=stride,
-                       groups=hidden_dim,
-                       activation=activation),
+            ConvBNReLU(
+                hidden_dim,
+                hidden_dim,
+                stride=stride,
+                groups=hidden_dim,
+                activation=activation),
             # pw-linear
             nn.Conv2d(hidden_dim, outplanes, 1, 1, 0, bias=False),
             nn.BatchNorm2d(outplanes),
@@ -97,6 +97,7 @@ class InvertedResidual(nn.Module):
         self.conv = nn.Sequential(*layers)
 
     def forward(self, x):
+
         def _inner_forward(x):
             if self.use_res_connect:
                 return x + self.conv(x)
@@ -122,15 +123,23 @@ def make_inverted_res_layer(block,
     layers = []
     for i in range(num_blocks):
         if i == 0:
-            layers.append(block(inplanes, planes, stride,
-                                expand_ratio=expand_ratio,
-                                activation=activation,
-                                with_cp=with_cp))
+            layers.append(
+                block(
+                    inplanes,
+                    planes,
+                    stride,
+                    expand_ratio=expand_ratio,
+                    activation=activation,
+                    with_cp=with_cp))
         else:
-            layers.append(block(inplanes, planes, 1,
-                                expand_ratio=expand_ratio,
-                                activation=activation,
-                                with_cp=with_cp))
+            layers.append(
+                block(
+                    inplanes,
+                    planes,
+                    1,
+                    expand_ratio=expand_ratio,
+                    activation=activation,
+                    with_cp=with_cp))
         inplanes = planes
     return nn.Sequential(*layers)
 
@@ -162,15 +171,10 @@ class MobileNetv2(BaseBackbone):
         super(MobileNetv2, self).__init__()
         block = InvertedResidual
         # expand_ratio, out_channel, n, stride
-        inverted_residual_setting = [
-            [1, 16, 1, 1],
-            [6, 24, 2, 2],
-            [6, 32, 3, 2],
-            [6, 64, 4, 2],
-            [6, 96, 3, 1],
-            [6, 160, 3, 2],
-            [6, 320, 1, 1]
-        ]
+        inverted_residual_setting = [[1, 16, 1, 1], [6, 24, 2,
+                                                     2], [6, 32, 3, 2],
+                                     [6, 64, 4, 2], [6, 96, 3, 1],
+                                     [6, 160, 3, 2], [6, 320, 1, 1]]
         self.widen_factor = widen_factor
         if isinstance(activation, str):
             activation = eval(activation)
@@ -211,9 +215,8 @@ class MobileNetv2(BaseBackbone):
         self.out_channel = int(self.out_channel * widen_factor) \
             if widen_factor > 1.0 else self.out_channel
 
-        self.conv_last = nn.Conv2d(self.inplanes,
-                                   self.out_channel,
-                                   1, 1, 0, bias=False)
+        self.conv_last = nn.Conv2d(
+            self.inplanes, self.out_channel, 1, 1, 0, bias=False)
         self.bn_last = nn.BatchNorm2d(self.out_channel)
         self.feat_dim = self.out_channel
 
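
The inverted_residual_setting rows reformatted above read as [expand_ratio, out_channel, num_blocks, stride] per stage. A short sketch of how those rows unroll into the seven layer{i} stages (illustrative; the int(c * widen_factor) channel scaling and the 32-channel stem are assumptions, since the file only shows that convention for the last conv):

    # Rows are [expand_ratio, out_channel, num_blocks, stride], as in the diff.
    inverted_residual_setting = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
                                 [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
                                 [6, 320, 1, 1]]

    widen_factor = 1.0
    inplanes = 32  # assumed stem width, for illustration only
    for i, (t, c, n, s) in enumerate(inverted_residual_setting, 1):
        planes = int(c * widen_factor)
        # Only the first block of a stage uses the stage stride; the rest
        # use stride 1, matching make_inverted_res_layer above.
        strides = [s] + [1] * (n - 1)
        print(f'layer{i}: {inplanes} -> {planes}, expand={t}, strides={strides}')
        inplanes = planes
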
diff --git a/tests/test_backbone.py b/tests/test_backbone.py
index 838f1b016..171344ab4 100644
--- a/tests/test_backbone.py
+++ b/tests/test_backbone.py
@@ -10,7 +10,7 @@ from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
 
 def is_block(modules):
     """Check if is a MobileNetv2 building block."""
-    if isinstance(modules, (InvertedResidual,)):
+    if isinstance(modules, (InvertedResidual, )):
         return True
     return False
 
@@ -35,40 +35,30 @@ def test_mobilenetv2_invertedresidual():
 
     with pytest.raises(AssertionError):
         # stride must be in [1, 2]
-        InvertedResidual(64, 16,
-                         stride=3, expand_ratio=6)
+        InvertedResidual(64, 16, stride=3, expand_ratio=6)
 
     # Test InvertedResidual forward, stride=1
-    block = InvertedResidual(64, 16,
-                             stride=1,
-                             expand_ratio=6)
+    block = InvertedResidual(64, 16, stride=1, expand_ratio=6)
     x = torch.randn(1, 64, 56, 56)
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 16, 56, 56])
 
     # Test InvertedResidual forward, stride=2
-    block = InvertedResidual(64, 16,
-                             stride=2,
-                             expand_ratio=6)
+    block = InvertedResidual(64, 16, stride=2, expand_ratio=6)
     x = torch.randn(1, 64, 56, 56)
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 16, 28, 28])
 
     # Test InvertedResidual with checkpoint forward
-    block = InvertedResidual(64, 16,
-                             stride=1,
-                             expand_ratio=6,
-                             with_cp=True)
+    block = InvertedResidual(64, 16, stride=1, expand_ratio=6, with_cp=True)
     assert block.with_cp
     x = torch.randn(1, 64, 56, 56)
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 16, 56, 56])
 
     # Test InvertedResidual with activation=nn.ReLU
-    block = InvertedResidual(64, 16,
-                             stride=1,
-                             expand_ratio=6,
-                             activation=nn.ReLU)
+    block = InvertedResidual(
+        64, 16, stride=1, expand_ratio=6, activation=nn.ReLU)
     x = torch.randn(1, 64, 56, 56)
     x_out = block(x)
     assert x_out.shape == torch.Size([1, 16, 56, 56])
@@ -193,8 +183,8 @@ def test_mobilenetv2_backbone():
     assert feat[6].shape == torch.Size([1, 320, 7, 7])
 
     # Test MobileNetv2 with layers 1, 3, 5 out forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6,
-                        out_indices=(0, 2, 4))
+    model = MobileNetv2(
+        widen_factor=1.0, activation=nn.ReLU6, out_indices=(0, 2, 4))
     model.init_weights()
     model.train()
 
@@ -206,8 +196,7 @@ def test_mobilenetv2_backbone():
     assert feat[2].shape == torch.Size([1, 96, 14, 14])
 
     # Test MobileNetv2 with checkpoint forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6,
-                        with_cp=True)
+    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6, with_cp=True)
     for m in model.modules():
         if is_block(m):
             assert m.with_cp
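
For reference, the shape behavior these reformatted tests pin down: an InvertedResidual with stride 1 preserves spatial size, while stride 2 halves it via the depthwise 3x3, and expand_ratio=6 widens the block's input channels internally (here 64 -> 384) before the linear projection back down to 16. A quick usage sketch against the same import path the tests use:

    import torch

    from mmcls.models.backbones.mobilenet_v2 import InvertedResidual

    x = torch.randn(1, 64, 56, 56)

    # stride=1 keeps the 56x56 spatial size.
    assert InvertedResidual(64, 16, stride=1, expand_ratio=6)(x).shape == \
        torch.Size([1, 16, 56, 56])

    # stride=2 downsamples 56x56 -> 28x28.
    assert InvertedResidual(64, 16, stride=2, expand_ratio=6)(x).shape == \
        torch.Size([1, 16, 28, 28])
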
From 0def3a56b647e28009ad242966e386f2b50827ff Mon Sep 17 00:00:00 2001
From: lixiaojie
Date: Fri, 5 Jun 2020 15:52:22 +0800
Subject: [PATCH 4/4] fix test

---
 tests/test_backbone.py | 21 +--------------------
 1 file changed, 1 insertion(+), 20 deletions(-)

diff --git a/tests/test_backbone.py b/tests/test_backbone.py
index 171344ab4..3f88802f2 100644
--- a/tests/test_backbone.py
+++ b/tests/test_backbone.py
@@ -97,7 +97,7 @@ def test_mobilenetv2_backbone():
         for param in layer.parameters():
             assert param.requires_grad is False
 
-    # Test MobileNetv2 with first stage frozen
+    # Test MobileNetv2 with bn frozen
     model = MobileNetv2(bn_frozen=True)
     model.init_weights()
     model.train()
@@ -163,25 +163,6 @@ def test_mobilenetv2_backbone():
     assert feat[5].shape == torch.Size([1, 160, 7, 7])
     assert feat[6].shape == torch.Size([1, 320, 7, 7])
 
-    # Test MobileNetv2 with BatchNorm forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
-    for m in model.modules():
-        if is_norm(m):
-            assert isinstance(m, _BatchNorm)
-    model.init_weights()
-    model.train()
-
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
     # Test MobileNetv2 with layers 1, 3, 5 out forward
     model = MobileNetv2(
         widen_factor=1.0, activation=nn.ReLU6, out_indices=(0, 2, 4))
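
The frozen-stage and bn-frozen assertions in these tests rely on the backbone's train() override switching frozen layers to eval mode and disabling their gradients. That override is not shown anywhere in this series; the sketch below is a hypothetical, mmcv-style implementation of the convention the tests appear to assume, not the actual file contents:

    import torch.nn as nn


    class TinyBackbone(nn.Module):
        """Hypothetical stand-in illustrating the assumed train() convention."""

        def __init__(self, frozen_stages=-1, bn_eval=True, bn_frozen=False):
            super(TinyBackbone, self).__init__()
            self.frozen_stages = frozen_stages
            self.bn_eval = bn_eval
            self.bn_frozen = bn_frozen
            self.conv1 = nn.Conv2d(3, 8, 3, stride=2, padding=1)
            self.bn1 = nn.BatchNorm2d(8)
            self.layer1 = nn.Sequential(
                nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8))

        def train(self, mode=True):
            super(TinyBackbone, self).train(mode)
            if self.bn_eval:
                for m in self.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()  # keep running stats fixed in train mode
                        if self.bn_frozen:
                            for param in m.parameters():
                                param.requires_grad = False
            if mode and self.frozen_stages >= 0:
                for m in [self.conv1, self.bn1]:
                    for param in m.parameters():
                        param.requires_grad = False
                for i in range(1, self.frozen_stages + 1):
                    layer = getattr(self, f'layer{i}')
                    layer.eval()
                    for param in layer.parameters():
                        param.requires_grad = False
            return self


    model = TinyBackbone(frozen_stages=1, bn_frozen=True)
    model.train()
    assert model.bn1.training is False
    assert all(not p.requires_grad for p in model.layer1.parameters())
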