Merge branch '1-job-failed-481592' into 'master'
Resolve "Job Failed #481592" Closes #1 See merge request open-mmlab/mmclassification!21pull/2/head
commit 6d3a4d12fa
@@ -5,60 +5,10 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
                       constant_init, kaiming_init)
 from torch.nn.modules.batchnorm import _BatchNorm
 
+from mmcls.models.utils import channel_shuffle, make_divisible
 from .base_backbone import BaseBackbone
 
 
-def channel_shuffle(x, groups):
-    """ Channel Shuffle operation.
-
-    This function enables cross-group information flow for multiple groups
-    convolution layers.
-
-    Args:
-        x (Tensor): The input tensor.
-        groups (int): The number of groups to divide the input tensor
-            in the channel dimension.
-
-    Returns:
-        Tensor: The output tensor after channel shuffle operation.
-    """
-
-    batchsize, num_channels, height, width = x.size()
-    assert (num_channels % groups == 0), ('num_channels should be '
-                                          'divisible by groups')
-    channels_per_group = num_channels // groups
-
-    x = x.view(batchsize, groups, channels_per_group, height, width)
-    x = torch.transpose(x, 1, 2).contiguous()
-    x = x.view(batchsize, -1, height, width)
-
-    return x
-
-
-def make_divisible(value, divisor, min_value=None):
-    """ Make divisible function.
-
-    This function ensures that all layers have a channel number that is
-    divisible by divisor.
-
-    Args:
-        value (int): The original channel number.
-        divisor (int): The divisor to fully divide the channel number.
-        min_value (int, optional): the minimum value of the output channel.
-
-    Returns:
-        int: The modified output channel number
-    """
-
-    if min_value is None:
-        min_value = divisor
-    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_value < 0.9 * value:
-        new_value += divisor
-    return new_value
-
-
 class ShuffleUnit(nn.Module):
     """ShuffleUnit block.
 
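For context, a minimal standalone sketch (not part of this MR's diff) of what the two helpers now imported from mmcls.models.utils compute; the logic mirrors the removed lines above, and the worked values match the test inputs further down.

import torch

x = torch.arange(6, dtype=torch.float32).view(1, 6, 1, 1)  # channels 0..5

# channel_shuffle with groups=3: view (N, C, H, W) as (N, 3, C // 3, H, W),
# swap the group and per-group axes, then flatten back, so channels are
# interleaved across the three groups.
batchsize, num_channels, height, width = x.size()
shuffled = (x.view(batchsize, 3, num_channels // 3, height, width)
            .transpose(1, 2).contiguous()
            .view(batchsize, -1, height, width))
print(shuffled.flatten().tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]

# make_divisible rounds a channel count to the nearest multiple of divisor
# and bumps the result up by one divisor if rounding lost more than 10%:
# make_divisible(34, 8) -> 32 (kept, since 32 >= 0.9 * 34), while
# make_divisible(10, 8) -> 16 (rounding gives 8, but 8 < 0.9 * 10).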
@@ -111,8 +61,8 @@ class ShuffleUnit(nn.Module):
         if self.combine == 'add':
             self.depthwise_stride = 1
             self._combine_func = self._add
-            assert inplanes == planes, ('inplanes must be equal to '
-                                        'planes when combine is add')
+            assert inplanes == planes, (
+                'inplanes must be equal to planes when combine is add')
         elif self.combine == 'concat':
             self.depthwise_stride = 2
             self._combine_func = self._concat
@@ -270,23 +220,20 @@ class ShuffleNetv1(BaseBackbone):
             stride=2,
             padding=1,
             bias=False)
 
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
-        self.layers = []
+        self.layers = nn.ModuleList()
         for i, num_blocks in enumerate(self.stage_blocks):
             first_block = False if i == 0 else True
             layer = self.make_layer(channels[i], num_blocks, first_block)
-            layer_name = f'layer{i + 1}'
-            self.add_module(layer_name, layer)
-            self.layers.append(layer_name)
+            self.layers.append(layer)
 
     def _freeze_stages(self):
         if self.frozen_stages >= 0:
             for param in self.conv1.parameters():
                 param.requires_grad = False
-            for i in range(1, self.frozen_stages + 1):
-                layer = getattr(self, f'layer{i}')
+            for i in range(self.frozen_stages):
+                layer = self.layers[i]
                 layer.eval()
                 for param in layer.parameters():
                     param.requires_grad = False
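Switching from a name-keyed list plus add_module to nn.ModuleList is what lets _freeze_stages (and forward, in the next hunk) index the stages directly. A toy sketch, independent of mmcls, of why a bare Python list alone would not work here:

import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4)]  # plain list: NOT registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4)])  # registered

print(len(list(PlainList().parameters())))       # 0: parameters invisible
print(len(list(WithModuleList().parameters())))  # 2: weight and bias found

The old code sidestepped this by registering every stage through add_module and storing only the layer names; ModuleList keeps the same registration while leaving the layer objects directly indexable.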
@@ -336,8 +283,7 @@ class ShuffleNetv1(BaseBackbone):
         x = self.maxpool(x)
 
         outs = []
-        for i, layer_name in enumerate(self.layers):
-            layer = getattr(self, layer_name)
+        for i, layer in enumerate(self.layers):
             x = layer(x)
             if i in self.out_indices:
                 outs.append(x)
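A hypothetical smoke test for the rewritten forward loop, assuming the backbone's default constructor arguments (the tests below rely on them as well) and that forward returns the features collected in outs:

import torch
from mmcls.models.backbones import ShuffleNetv1

model = ShuffleNetv1()
model.init_weights()
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
# one feature map per entry in out_indices, appended in stage order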
(the following hunks are from the ShuffleNetv1 test module)
@@ -4,7 +4,7 @@ from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmcls.models.backbones import ShuffleNetv1
-from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit, make_divisible
+from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit
 
 
 def is_block(modules):
@@ -30,27 +30,15 @@ def check_norm_state(modules, train_state):
     return True
 
 
-def test_make_divisible():
-    # test min_value is None
-    make_divisible(34, 8, None)
-
-    # test new_value < 0.9 * value
-    make_divisible(10, 8, None)
-
-
 def test_shufflenetv1_shuffleuint():
 
     with pytest.raises(ValueError):
         # combine must be in ['add', 'concat']
         ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')
 
-    with pytest.raises(ValueError):
-        # inplanes must be divisible by groups
-        ShuffleUnit(64, 64, groups=3, first_block=True, combine='add')
-
     with pytest.raises(AssertionError):
         # inplanes must be equal to planes when combine='add'
-        ShuffleUnit(64, 24, groups=3, first_block=True, combine='add')
+        ShuffleUnit(64, 24, groups=4, first_block=True, combine='add')
 
     # Test ShuffleUnit with combine='add'
     block = ShuffleUnit(24, 24, groups=3, first_block=True, combine='add')
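On the groups=3 -> groups=4 change: with ShuffleUnit(64, 24, groups=3, ...), 64 is not divisible by 3, so the old divisibility ValueError could fire before the inplanes == planes assertion ever ran, which is presumably the failure this MR resolves; groups=4 divides 64 evenly, leaving the assertion as the first check that can fail. (The standalone divisibility case is dropped above, consistent with channel_shuffle and its check moving to mmcls.models.utils.) A short sketch of what the fixed test asserts:

import pytest
from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit

# combine='add' requires inplanes == planes; 64 != 24, and since
# 64 % 4 == 0 no divisibility error can preempt the assertion.
with pytest.raises(AssertionError):
    ShuffleUnit(64, 24, groups=4, first_block=True, combine='add')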
@@ -104,11 +92,10 @@ def test_shufflenetv1_backbone():
     model = ShuffleNetv1(frozen_stages=frozen_stages)
     model.init_weights()
     model.train()
-    for layer in [model.conv1]:
-        for param in layer.parameters():
-            assert param.requires_grad is False
-    for i in range(1, frozen_stages + 1):
-        layer = getattr(model, f'layer{i}')
+    for param in model.conv1.parameters():
+        assert param.requires_grad is False
+    for i in range(frozen_stages):
+        layer = model.layers[i]
         for mod in layer.modules():
             if isinstance(mod, _BatchNorm):
                 assert mod.training is False