Merge branch 'master' of gitlab.sz.sensetime.com:open-mmlab/mmclassification into dev_shufflenetv2
commit
cf85db0f7c
|
@ -1,4 +1,6 @@
|
|||
image: registry.sensetime.com/eig-research/pytorch:1.3.1-cuda10.1-cudnn7-devel
|
||||
variables:
|
||||
PARROTS_IMAGE: registry.sensetime.com/platform/product:pat20200612
|
||||
PYTORCH_IMAGE: registry.sensetime.com/eig-research/pytorch:1.3.1-cuda10.1-cudnn7-devel
|
||||
|
||||
stages:
|
||||
- linting
|
||||
|
@ -14,6 +16,7 @@ before_script:
|
|||
- python -c "import torch; print(torch.__version__)"
|
||||
|
||||
linting:
|
||||
image: $PYTORCH_IMAGE
|
||||
stage: linting
|
||||
script:
|
||||
- pip install flake8 yapf isort
|
||||
|
@ -21,14 +24,22 @@ linting:
|
|||
- isort -rc --check-only --diff mmcls/ tools/ tests/
|
||||
- yapf -r -d mmcls/ tools/ tests/ configs/
|
||||
|
||||
test:
|
||||
.test_template: &test_template_def
|
||||
stage: test
|
||||
script:
|
||||
- echo "Start building..."
|
||||
- pip install pillow==6.2.2
|
||||
- pip install pillow==6.2.1
|
||||
- pip install -e .
|
||||
- python -c "import mmcls; print(mmcls.__version__)"
|
||||
- echo "Start testing..."
|
||||
- pip install pytest coverage
|
||||
- coverage run --source mmcls -m pytest tests/
|
||||
- coverage report -m
|
||||
|
||||
test:pytorch1.3-cuda10:
|
||||
image: $PYTORCH_IMAGE
|
||||
<<: *test_template_def
|
||||
|
||||
test:pat0.6.0dev-cuda9:
|
||||
image: $PARROTS_IMAGE
|
||||
<<: *test_template_def
|
|
@ -3,6 +3,7 @@ import copy
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import OptimizerHook
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..utils import allreduce_grads
|
||||
from .utils import cast_tensor_type
|
||||
|
@ -95,7 +96,7 @@ def wrap_fp16_model(model):
|
|||
|
||||
|
||||
def patch_norm_fp32(module):
|
||||
if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
|
||||
if isinstance(module, (_BatchNorm, nn.GroupNorm)):
|
||||
module.float()
|
||||
module.forward = patch_forward_method(module.forward, torch.half,
|
||||
torch.float)
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch.nn as nn
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
|
||||
kaiming_init)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .channel_shuffle import channel_shuffle
|
||||
from .make_divisible import make_divisible
|
||||
|
||||
__all__ = ['channel_shuffle', 'make_divisible']
|
|
@ -0,0 +1,28 @@
|
|||
import torch
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
|
||||
"""Channel Shuffle operation.
|
||||
|
||||
This function enables cross-group information flow for multiple groups
|
||||
convolution layers.
|
||||
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
groups (int): The number of groups to divide the input tensor
|
||||
in the channel dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor after channel shuffle operation.
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = x.size()
|
||||
assert (num_channels % groups == 0), ('num_channels should be '
|
||||
'divisible by groups')
|
||||
channels_per_group = num_channels // groups
|
||||
|
||||
x = x.view(batch_size, groups, channels_per_group, height, width)
|
||||
x = torch.transpose(x, 1, 2).contiguous()
|
||||
x = x.view(batch_size, -1, height, width)
|
||||
|
||||
return x
|
|
@ -0,0 +1,24 @@
|
|||
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
|
||||
"""Make divisible function.
|
||||
|
||||
This function rounds the channel number down to the nearest value that can
|
||||
be divisible by the divisor.
|
||||
|
||||
Args:
|
||||
value (int): The original channel number.
|
||||
divisor (int): The divisor to fully divide the channel number.
|
||||
min_value (int, optional): The minimum value of the output channel.
|
||||
Default: None, means that the minimum value equal to the divisor.
|
||||
min_ratio (float, optional): The minimum ratio of the rounded channel
|
||||
number to the original channel number. Default: 0.9.
|
||||
Returns:
|
||||
int: The modified output channel number
|
||||
"""
|
||||
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than (1-min_ratio).
|
||||
if new_value < min_ratio * value:
|
||||
new_value += divisor
|
||||
return new_value
|
|
@ -1,4 +1,2 @@
|
|||
mmcv>=0.3.0
|
||||
mmcv-nightly
|
||||
numpy
|
||||
torch>=1.1
|
||||
torchvision
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from torch.nn.modules import AvgPool2d
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet, ResNetV1d
|
||||
from mmcls.models.backbones.resnet import BasicBlock, Bottleneck, ResLayer
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.utils import channel_shuffle, make_divisible
|
||||
|
||||
|
||||
def test_make_divisible():
|
||||
# test min_value is None
|
||||
result = make_divisible(34, 8, None)
|
||||
assert result == 32
|
||||
|
||||
# test when new_value > min_ratio * value
|
||||
result = make_divisible(10, 8, min_ratio=0.9)
|
||||
assert result == 16
|
||||
|
||||
# test min_value = 0.8
|
||||
result = make_divisible(33, 8, min_ratio=0.8)
|
||||
assert result == 32
|
||||
|
||||
|
||||
def test_channel_shuffle():
|
||||
x = torch.randn(1, 24, 56, 56)
|
||||
with pytest.raises(AssertionError):
|
||||
# num_channels should be divisible by groups
|
||||
channel_shuffle(x, 7)
|
||||
|
||||
groups = 3
|
||||
batch_size, num_channels, height, width = x.size()
|
||||
channels_per_group = num_channels // groups
|
||||
out = channel_shuffle(x, groups)
|
||||
# test the output value when groups = 3
|
||||
for b in range(batch_size):
|
||||
for c in range(num_channels):
|
||||
c_out = c % channels_per_group * groups + c // channels_per_group
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
assert x[b, c, i, j] == out[b, c_out, i, j]
|
Loading…
Reference in New Issue