add unit test

parent 2ee95c44ce
commit 98f5b49ffe
@@ -1,5 +1,3 @@
-from .mobilenet_v2 import MobileNetv2
+from .shufflenet_v1 import ShuffleNetv1
 
-__all__ = [
-    'MobileNetv2',
-]
+__all__ = ['ShuffleNetv1']
@@ -4,8 +4,8 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
+from mmcv.runner import load_checkpoint
 
-from ..runner import load_checkpoint
 from .base_backbone import BaseBackbone
 from .weight_init import constant_init, kaiming_init
 
@@ -28,12 +28,7 @@ def conv1x1(inplanes, planes, groups=1):
     - Normal pointwise convolution when groups == 1
     - Grouped pointwise convolution when groups > 1
     """
-    return nn.Conv2d(
-        inplanes,
-        planes,
-        kernel_size=1,
-        groups=groups,
-        stride=1)
+    return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1)
 
 
 def channel_shuffle(x, groups):
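For context, channel_shuffle is the operation the grouped 1x1 convolutions above depend on: it interleaves channels across groups so information can flow between groups. A minimal sketch of the standard implementation (the body in this file is elided by the diff, so details may differ):

import torch


def channel_shuffle_sketch(x, groups):
    """Interleave channels across groups: (N, C, H, W) -> same shape.

    Reshape to (N, groups, C // groups, H, W), swap the two channel
    axes, and flatten back, so each group's output mixes channels
    produced by every group of the preceding grouped 1x1 conv.
    """
    n, c, h, w = x.size()
    x = x.view(n, groups, c // groups, h, w)
    x = x.transpose(1, 2).contiguous()
    return x.view(n, c, h, w)


x = torch.arange(12, dtype=torch.float32).view(1, 12, 1, 1)
print(channel_shuffle_sketch(x, groups=3).flatten())
# tensor([ 0., 4., 8., 1., 5., 9., 2., 6., 10., 3., 7., 11.])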
@@ -65,8 +60,8 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
 
 
-# noinspection PyShadowingNames,PyShadowingNames
+
 class ShuffleUnit(nn.Module):
 
     def __init__(self,
                  inplanes,
                  planes,
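_make_divisible, referenced in the hunk header, is conventionally the rounding helper that originated in the TensorFlow MobileNet code: it snaps a channel count to a multiple of divisor without dropping more than 10% of the original value. A sketch under that assumption (this file's exact body is not shown):

def make_divisible_sketch(v, divisor, min_value=None):
    # Round v to the nearest multiple of divisor, never below min_value
    # and never losing more than 10% of v.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


print(make_divisible_sketch(239, 4))  # 240: stage widths stay divisible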
@@ -96,20 +91,24 @@ class ShuffleUnit(nn.Module):
                              "Only \"add\" and \"concat\" are "
                              "supported".format(self.combine))
 
+        if combine == 'add':
+            assert inplanes == planes, \
+                'inplanes must be equal to outplanes when combine is add'
+
         self.first_1x1_groups = self.groups if first_block else 1
         self.g_conv_1x1_compress = self._make_grouped_conv1x1(
             self.inplanes,
             self.bottleneck_channels,
             self.first_1x1_groups,
             batch_norm=True,
-            relu=True
-        )
+            relu=True)
 
-        self.depthwise_conv3x3 = conv3x3(self.bottleneck_channels,
+        self.depthwise_conv3x3 = conv3x3(
+            self.bottleneck_channels,
             self.bottleneck_channels,
             stride=self.depthwise_stride,
             groups=self.bottleneck_channels)
-        self.nn.BatchNorm2d_after_depthwise = \
+        self.bn_after_depthwise = \
             nn.BatchNorm2d(self.bottleneck_channels)
 
         self.g_conv_1x1_expand = \
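The new assert exists because combine='add' fuses the residual and the branch element-wise, whereas 'concat' stacks them along the channel axis. Shapes illustrate it (the numbers here are arbitrary):

import torch

residual = torch.randn(1, 24, 28, 28)
branch = torch.randn(1, 24, 28, 28)

fused = residual + branch                    # combine='add': channels must match
assert fused.shape == (1, 24, 28, 28)

stacked = torch.cat((residual, branch), 1)   # combine='concat': channels sum
assert stacked.shape == (1, 48, 28, 28)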
@@ -132,8 +131,11 @@ class ShuffleUnit(nn.Module):
         return torch.cat((x, out), 1)
 
     @staticmethod
-    def _make_grouped_conv1x1(inplanes, planes, groups,
-                              batch_norm=True, relu=False):
+    def _make_grouped_conv1x1(inplanes,
+                              planes,
+                              groups,
+                              batch_norm=True,
+                              relu=False):
 
         modules = OrderedDict()
 
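The helper's body is elided; judging from the OrderedDict above, it assembles a grouped 1x1 conv with optional BN and ReLU into an nn.Sequential. A hypothetical sketch (the name and exact structure are assumptions, not this file's code):

from collections import OrderedDict

import torch.nn as nn


def make_grouped_conv1x1_sketch(inplanes, planes, groups,
                                batch_norm=True, relu=False):
    # inplanes must be divisible by groups for a grouped 1x1 conv
    modules = OrderedDict()
    modules['conv1x1'] = nn.Conv2d(
        inplanes, planes, kernel_size=1, groups=groups)
    if batch_norm:
        modules['batch_norm'] = nn.BatchNorm2d(planes)
    if relu:
        modules['relu'] = nn.ReLU()
    return nn.Sequential(modules)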
@@ -150,6 +152,7 @@ class ShuffleUnit(nn.Module):
         return conv
 
     def forward(self, x):
+
         def _inner_forward(x):
             residual = x
 
@@ -159,7 +162,7 @@ class ShuffleUnit(nn.Module):
             out = self.g_conv_1x1_compress(x)
             out = channel_shuffle(out, self.groups)
             out = self.depthwise_conv3x3(out)
-            out = self.nn.BatchNorm2d_after_depthwise(out)
+            out = self.bn_after_depthwise(out)
             out = self.g_conv_1x1_expand(out)
 
             out = self._combine_func(residual, out)
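_inner_forward is the hook for the with_cp flag: with checkpointing on, the unit's intermediate activations are recomputed during backward rather than stored. The dispatch itself sits outside this hunk, but it usually follows the standard mmcv pattern, sketched here with a toy block rather than the actual ShuffleUnit:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedBlock(nn.Module):
    """Toy stand-in showing the with_cp pattern used by ShuffleUnit."""

    def __init__(self, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.body = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())

    def forward(self, x):

        def _inner_forward(x):
            return self.body(x)

        if self.with_cp and x.requires_grad:
            # trade compute for memory: recompute the block on backward
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)


x = torch.randn(1, 8, 16, 16, requires_grad=True)
CheckpointedBlock(with_cp=True)(x).sum().backward()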
@@ -230,10 +233,10 @@ class ShuffleNetv1(BaseBackbone):
         self.conv1 = conv3x3(3, self.inplanes, stride=2)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
-        self.layer2 = self._make_layer(channels[0], blocks[0],
-                                       first_block=False, with_cp=with_cp)
-        self.layer3 = self._make_layer(channels[1], blocks[1], with_cp=with_cp)
-        self.layer4 = self._make_layer(channels[2], blocks[2], with_cp=with_cp)
+        self.layer1 = self._make_layer(
+            channels[0], blocks[0], first_block=False, with_cp=with_cp)
+        self.layer2 = self._make_layer(channels[1], blocks[1], with_cp=with_cp)
+        self.layer3 = self._make_layer(channels[2], blocks[2], with_cp=with_cp)
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
@@ -248,21 +251,23 @@ class ShuffleNetv1(BaseBackbone):
         else:
             raise TypeError('pretrained must be a str or None')
 
-    def _make_layer(self,
-                    outplanes,
-                    blocks,
-                    first_block=True,
-                    with_cp=False):
+    def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False):
         layers = []
         for i in range(blocks):
             if i == 0:
-                layers.append(ShuffleUnit(self.inplanes, outplanes,
+                layers.append(
+                    ShuffleUnit(
+                        self.inplanes,
+                        outplanes,
                         groups=self.groups,
                         first_block=first_block,
                         combine='concat',
                         with_cp=with_cp))
             else:
-                layers.append(ShuffleUnit(self.inplanes, outplanes,
+                layers.append(
+                    ShuffleUnit(
+                        self.inplanes,
+                        outplanes,
                         groups=self.groups,
                         first_block=True,
                         combine='add',
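Each stage therefore opens with a downsampling combine='concat' unit (channels grow, resolution halves) and continues with combine='add' units. A usage sketch, assuming each stage is an nn.Sequential and the groups=3 channel plan (240, 480, 960) that the tests below assert:

from mmcls.models.backbones import ShuffleNetv1

model = ShuffleNetv1(groups=3)
for name in ('layer1', 'layer2', 'layer3'):
    stage = getattr(model, name)
    # first unit: combine='concat' (downsampling); the rest: combine='add'
    print(name, len(stage), 'units')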
@@ -274,7 +279,9 @@ class ShuffleNetv1(BaseBackbone):
     def forward(self, x):
         x = self.conv1(x)
         x = self.maxpool(x)
+
         outs = []
+        x = self.layer1(x)
         if 0 in self.out_indices:
             outs.append(x)
         x = self.layer2(x)
@@ -283,8 +290,7 @@ class ShuffleNetv1(BaseBackbone):
         x = self.layer3(x)
         if 2 in self.out_indices:
             outs.append(x)
-        x = self.layer4(x)
-        if 3 in self.out_indices:
-            outs.append(x)
+        outs.append(x)
 
         if len(outs) == 1:
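With this fix, forward routes x through the renamed layer1 and always appends the final feature map, so a caller gets one tensor per requested index plus the last stage's output. The remaining hunks update the ShuffleNetv1 unit-test module, which pins down exactly this behaviour; for example:

import torch

from mmcls.models.backbones import ShuffleNetv1

model = ShuffleNetv1(groups=3)
model.init_weights()
model.train()

feat = model(torch.randn(1, 3, 224, 224))
assert len(feat) == 4  # default out_indices plus the always-appended last map
assert feat[-1].shape == torch.Size([1, 960, 7, 7])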
@@ -1,16 +1,15 @@
 import pytest
 import torch
-import torch.nn as nn
 from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
 
-from mmcls.models.backbones import MobileNetv2
-from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
+from mmcls.models.backbones import ShuffleNetv1
+from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit
 
 
 def is_block(modules):
     """Check if is ResNet building block."""
-    if isinstance(modules, (InvertedResidual, )):
+    if isinstance(modules, (ShuffleUnit, )):
         return True
     return False
 
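is_norm and check_norm_state are used below but live in the part of the file the diff elides; in mmcls backbone tests they conventionally look like this (a sketch, assuming the usual definitions):

from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm


def is_norm(modules):
    """Check if is one of the norm layers."""
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


def check_norm_state(modules, train_state):
    """Check if norm layers are in the expected train/eval state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True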
@@ -31,62 +30,58 @@ def check_norm_state(modules, train_state):
     return True
 
 
-def test_mobilenetv2_invertedresidual():
+def test_shufflenetv1_shuffleuint():
+
+    with pytest.raises(ValueError):
+        # combine must be in ['add', 'concat']
+        ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')
+
+    with pytest.raises(ValueError):
+        # in_channels must be divisible by groups
+        ShuffleUnit(64, 64, groups=3, first_block=True, combine='add')
+
     with pytest.raises(AssertionError):
-        # stride must be in [1, 2]
-        InvertedResidual(64, 16, stride=3, expand_ratio=6)
+        # inplanes must be equal to outplanes when combine='add'
+        ShuffleUnit(64, 24, groups=3, first_block=True, combine='add')
 
-    # Test InvertedResidual with checkpoint forward, stride=1
-    block = InvertedResidual(64, 16, stride=1, expand_ratio=6)
-    x = torch.randn(1, 64, 56, 56)
+    # Test ShuffleUnit with combine='add'
+    block = ShuffleUnit(24, 24, groups=3, first_block=True, combine='add')
+    x = torch.randn(1, 24, 56, 56)
     x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 56, 56])
+    assert x_out.shape == torch.Size([1, 24, 56, 56])
 
-    # Test InvertedResidual with checkpoint forward, stride=2
-    block = InvertedResidual(64, 16, stride=2, expand_ratio=6)
-    x = torch.randn(1, 64, 56, 56)
+    # Test ShuffleUnit with combine='concat'
+    block = ShuffleUnit(24, 240, groups=3, first_block=True, combine='concat')
+    x = torch.randn(1, 24, 56, 56)
     x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 28, 28])
+    assert x_out.shape == torch.Size([1, 240, 28, 28])
 
-    # Test InvertedResidual with checkpoint forward
-    block = InvertedResidual(64, 16, stride=1, expand_ratio=6, with_cp=True)
-    assert block.with_cp
-    x = torch.randn(1, 64, 56, 56)
+    # Test ShuffleUnit with checkpoint forward
+    block = ShuffleUnit(
+        24, 24, groups=3, first_block=True, combine='add', with_cp=True)
+    x = torch.randn(1, 24, 56, 56)
     x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 56, 56])
-
-    # Test InvertedResidual with activation=nn.ReLU
-    block = InvertedResidual(
-        64, 16, stride=1, expand_ratio=6, activation=nn.ReLU)
-    x = torch.randn(1, 64, 56, 56)
-    x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 56, 56])
+    assert x_out.shape == torch.Size([1, 24, 56, 56])
 
 
-def test_mobilenetv2_backbone():
-    with pytest.raises(TypeError):
-        # pretrained must be a string path
-        model = MobileNetv2()
-        model.init_weights(pretrained=0)
-
+def test_shufflenetv1_backbone():
-    with pytest.raises(AssertionError):
-        # frozen_stages must less than 7
-        MobileNetv2(frozen_stages=8)
+    with pytest.raises(ValueError):
+        # groups must be in [1, 2, 3, 4, 8]
+        ShuffleNetv1(groups=10)
 
-    # Test MobileNetv2
-    model = MobileNetv2()
+    # Test ShuffleNetv1 norm state
+    model = ShuffleNetv1()
     model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), False)
 
-    # Test MobileNetv2 with first stage frozen
+    # Test ShuffleNetv1 with first stage frozen
     frozen_stages = 1
-    model = MobileNetv2(frozen_stages=frozen_stages)
+    model = ShuffleNetv1(frozen_stages=frozen_stages)
     model.init_weights()
     model.train()
-    assert model.bn1.training is False
-    for layer in [model.conv1, model.bn1]:
+    for layer in [model.conv1]:
         for param in layer.parameters():
             assert param.requires_grad is False
     for i in range(1, frozen_stages + 1):
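A note on the concat case above: ShuffleUnit(24, 240, combine='concat') halves the spatial size, and the branch presumably produces planes - inplanes channels so that concatenation with the pooled 24-channel residual lands on 240:

inplanes, planes = 24, 240
branch_planes = planes - inplanes       # what the unit's expand conv emits
assert branch_planes == 216
assert inplanes + branch_planes == 240  # matches the asserted [1, 240, 28, 28]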
@@ -97,13 +92,12 @@ def test_mobilenetv2_backbone():
         for param in layer.parameters():
             assert param.requires_grad is False
 
-    # Test MobileNetv2 with bn frozen
-    model = MobileNetv2(bn_frozen=True)
+    # Test ShuffleNetv1 with bn frozen
+    model = ShuffleNetv1(bn_frozen=True)
     model.init_weights()
     model.train()
-    assert model.bn1.training is False
 
-    for i in range(1, 8):
+    for i in range(1, 4):
         layer = getattr(model, f'layer{i}')
 
         for mod in layer.modules():
@@ -112,85 +106,52 @@ def test_mobilenetv2_backbone():
         for params in mod.parameters():
             params.requires_grad = False
 
-    # Test MobileNetv2 forward with widen_factor=1.0
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
+    # Test ShuffleNetv1 forward with groups=3
+    model = ShuffleNetv1(groups=3)
     model.init_weights()
     model.train()
 
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
-    # Test MobileNetv2 forward with activation=nn.ReLU
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU)
-    model.init_weights()
-    model.train()
-
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
-    # Test MobileNetv2 with BatchNorm forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
     for m in model.modules():
         if is_norm(m):
             assert isinstance(m, _BatchNorm)
-    model.init_weights()
-    model.train()
-
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
-    # Test MobileNetv2 with layers 1, 3, 5 out forward
-    model = MobileNetv2(
-        widen_factor=1.0, activation=nn.ReLU6, out_indices=(0, 2, 4))
-    model.init_weights()
-    model.train()
-
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 32, 28, 28])
-    assert feat[2].shape == torch.Size([1, 96, 14, 14])
+    assert feat[0].shape == torch.Size([1, 240, 28, 28])
+    assert feat[1].shape == torch.Size([1, 480, 14, 14])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+    assert feat[3].shape == torch.Size([1, 960, 7, 7])
 
-    # Test MobileNetv2 with checkpoint forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6, with_cp=True)
-    for m in model.modules():
-        if is_block(m):
-            assert m.with_cp
+    # Test ShuffleNetv1 forward with layers 1, 2
+    model = ShuffleNetv1(groups=3, out_indices=(1, 2))
     model.init_weights()
     model.train()
 
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+
     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size([1, 480, 14, 14])
+    assert feat[1].shape == torch.Size([1, 960, 7, 7])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+
+    # Test ShuffleNetv1 forward with checkpoint forward
+    model = ShuffleNetv1(groups=3, with_cp=True)
+    model.init_weights()
+    model.train()
+
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 240, 28, 28])
+    assert feat[1].shape == torch.Size([1, 480, 14, 14])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+    assert feat[3].shape == torch.Size([1, 960, 7, 7])
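To run just this module (the path is a guess; the capture does not show where the test file lives):

import pytest

pytest.main(['-q', 'tests/test_backbones/test_shufflenet_v1.py'])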