Merge branch 'master' of gitlab.sz.sensetime.com:open-mmlab/mmclassification into dev_shufflenetv2

pull/2/head
lixiaojie 2020-06-15 20:42:22 +08:00
commit cf85db0f7c
9 changed files with 112 additions and 9 deletions

View File

@ -1,4 +1,6 @@
image: registry.sensetime.com/eig-research/pytorch:1.3.1-cuda10.1-cudnn7-devel
variables:
PARROTS_IMAGE: registry.sensetime.com/platform/product:pat20200612
PYTORCH_IMAGE: registry.sensetime.com/eig-research/pytorch:1.3.1-cuda10.1-cudnn7-devel
stages:
- linting
@ -14,6 +16,7 @@ before_script:
- python -c "import torch; print(torch.__version__)"
linting:
image: $PYTORCH_IMAGE
stage: linting
script:
- pip install flake8 yapf isort
@ -21,14 +24,22 @@ linting:
- isort -rc --check-only --diff mmcls/ tools/ tests/
- yapf -r -d mmcls/ tools/ tests/ configs/
test:
.test_template: &test_template_def
stage: test
script:
- echo "Start building..."
- pip install pillow==6.2.2
- pip install pillow==6.2.1
- pip install -e .
- python -c "import mmcls; print(mmcls.__version__)"
- echo "Start testing..."
- pip install pytest coverage
- coverage run --source mmcls -m pytest tests/
- coverage report -m
test:pytorch1.3-cuda10:
image: $PYTORCH_IMAGE
<<: *test_template_def
test:pat0.6.0dev-cuda9:
image: $PARROTS_IMAGE
<<: *test_template_def

View File

@ -3,6 +3,7 @@ import copy
import torch
import torch.nn as nn
from mmcv.runner import OptimizerHook
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..utils import allreduce_grads
from .utils import cast_tensor_type
@ -95,7 +96,7 @@ def wrap_fp16_model(model):
def patch_norm_fp32(module):
if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
if isinstance(module, (_BatchNorm, nn.GroupNorm)):
module.float()
module.forward = patch_forward_method(module.forward, torch.half,
torch.float)

View File

@ -2,7 +2,7 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone

View File

@ -0,0 +1,4 @@
from .channel_shuffle import channel_shuffle
from .make_divisible import make_divisible
__all__ = ['channel_shuffle', 'make_divisible']

View File

@ -0,0 +1,28 @@
import torch
def channel_shuffle(x, groups):
"""Channel Shuffle operation.
This function enables cross-group information flow for multiple groups
convolution layers.
Args:
x (Tensor): The input tensor.
groups (int): The number of groups to divide the input tensor
in the channel dimension.
Returns:
Tensor: The output tensor after channel shuffle operation.
"""
batch_size, num_channels, height, width = x.size()
assert (num_channels % groups == 0), ('num_channels should be '
'divisible by groups')
channels_per_group = num_channels // groups
x = x.view(batch_size, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batch_size, -1, height, width)
return x

View File

@ -0,0 +1,24 @@
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
"""Make divisible function.
This function rounds the channel number down to the nearest value that can
be divisible by the divisor.
Args:
value (int): The original channel number.
divisor (int): The divisor to fully divide the channel number.
min_value (int, optional): The minimum value of the output channel.
Default: None, means that the minimum value equal to the divisor.
min_ratio (float, optional): The minimum ratio of the rounded channel
number to the original channel number. Default: 0.9.
Returns:
int: The modified output channel number
"""
if min_value is None:
min_value = divisor
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
if new_value < min_ratio * value:
new_value += divisor
return new_value

View File

@ -1,4 +1,2 @@
mmcv>=0.3.0
mmcv-nightly
numpy
torch>=1.1
torchvision

View File

@ -1,7 +1,7 @@
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch.nn.modules import AvgPool2d
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import ResNet, ResNetV1d
from mmcls.models.backbones.resnet import BasicBlock, Bottleneck, ResLayer

View File

@ -0,0 +1,37 @@
import pytest
import torch
from mmcls.models.utils import channel_shuffle, make_divisible
def test_make_divisible():
# test min_value is None
result = make_divisible(34, 8, None)
assert result == 32
# test when new_value > min_ratio * value
result = make_divisible(10, 8, min_ratio=0.9)
assert result == 16
# test min_value = 0.8
result = make_divisible(33, 8, min_ratio=0.8)
assert result == 32
def test_channel_shuffle():
x = torch.randn(1, 24, 56, 56)
with pytest.raises(AssertionError):
# num_channels should be divisible by groups
channel_shuffle(x, 7)
groups = 3
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
out = channel_shuffle(x, groups)
# test the output value when groups = 3
for b in range(batch_size):
for c in range(num_channels):
c_out = c % channels_per_group * groups + c // channels_per_group
for i in range(height):
for j in range(width):
assert x[b, c, i, j] == out[b, c_out, i, j]