From 29975930f935145668c0c19dee0f23d30ef86d5f Mon Sep 17 00:00:00 2001 From: lixiaojie Date: Mon, 15 Jun 2020 16:42:15 +0800 Subject: [PATCH 1/2] Dev/backbone utils --- mmcls/models/utils/__init__.py | 4 +++ mmcls/models/utils/channel_shuffle.py | 28 ++++++++++++++++++++ mmcls/models/utils/make_divisible.py | 24 +++++++++++++++++ tests/test_backbones/test_utils.py | 37 +++++++++++++++++++++++++++ 4 files changed, 93 insertions(+) create mode 100644 mmcls/models/utils/__init__.py create mode 100644 mmcls/models/utils/channel_shuffle.py create mode 100644 mmcls/models/utils/make_divisible.py create mode 100644 tests/test_backbones/test_utils.py diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py new file mode 100644 index 00000000..2b7fc302 --- /dev/null +++ b/mmcls/models/utils/__init__.py @@ -0,0 +1,4 @@ +from .channel_shuffle import channel_shuffle +from .make_divisible import make_divisible + +__all__ = ['channel_shuffle', 'make_divisible'] diff --git a/mmcls/models/utils/channel_shuffle.py b/mmcls/models/utils/channel_shuffle.py new file mode 100644 index 00000000..51d6d98c --- /dev/null +++ b/mmcls/models/utils/channel_shuffle.py @@ -0,0 +1,28 @@ +import torch + + +def channel_shuffle(x, groups): + """Channel Shuffle operation. + + This function enables cross-group information flow for multiple groups + convolution layers. + + Args: + x (Tensor): The input tensor. + groups (int): The number of groups to divide the input tensor + in the channel dimension. + + Returns: + Tensor: The output tensor after channel shuffle operation. 
+ """ + + batch_size, num_channels, height, width = x.size() + assert (num_channels % groups == 0), ('num_channels should be ' + 'divisible by groups') + channels_per_group = num_channels // groups + + x = x.view(batch_size, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(batch_size, -1, height, width) + + return x diff --git a/mmcls/models/utils/make_divisible.py b/mmcls/models/utils/make_divisible.py new file mode 100644 index 00000000..02ee047c --- /dev/null +++ b/mmcls/models/utils/make_divisible.py @@ -0,0 +1,24 @@ +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that is + divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, meaning the minimum value equals the divisor. + min_ratio (float, optional): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). 
+ if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/tests/test_backbones/test_utils.py b/tests/test_backbones/test_utils.py new file mode 100644 index 00000000..deee9a56 --- /dev/null +++ b/tests/test_backbones/test_utils.py @@ -0,0 +1,37 @@ +import pytest +import torch + +from mmcls.models.utils import channel_shuffle, make_divisible + + +def test_make_divisible(): + # test min_value is None + result = make_divisible(34, 8, None) + assert result == 32 + + # test when new_value < min_ratio * value + result = make_divisible(10, 8, min_ratio=0.9) + assert result == 16 + + # test min_ratio = 0.8 + result = make_divisible(33, 8, min_ratio=0.8) + assert result == 32 + + +def test_channel_shuffle(): + x = torch.randn(1, 24, 56, 56) + with pytest.raises(AssertionError): + # num_channels should be divisible by groups + channel_shuffle(x, 7) + + groups = 3 + batch_size, num_channels, height, width = x.size() + channels_per_group = num_channels // groups + out = channel_shuffle(x, groups) + # test the output value when groups = 3 + for b in range(batch_size): + for c in range(num_channels): + c_out = c % channels_per_group * groups + c // channels_per_group + for i in range(height): + for j in range(width): + assert x[b, c, i, j] == out[b, c_out, i, j] From a1da2013ad07bd66f2b5b0b24d2d9c1a16b11902 Mon Sep 17 00:00:00 2001 From: wangshiguang Date: Mon, 15 Jun 2020 17:43:40 +0800 Subject: [PATCH 2/2] add pat ci image --- .gitlab-ci.yml | 17 ++++++++++++++--- mmcls/core/fp16/hooks.py | 3 ++- mmcls/models/backbones/resnet.py | 2 +- requirements.txt | 4 +--- tests/test_backbones/test_resnet.py | 2 +- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0ae0e1a1..fb9ac45f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,6 @@ -image: registry.sensetime.com/eig-research/pytorch:1.3.1-cuda10.1-cudnn7-devel +variables: + PARROTS_IMAGE: registry.sensetime.com/platform/product:pat20200612 + 
PYTORCH_IMAGE: registry.sensetime.com/eig-research/pytorch:1.3.1-cuda10.1-cudnn7-devel stages: - linting @@ -14,6 +16,7 @@ before_script: - python -c "import torch; print(torch.__version__)" linting: + image: $PYTORCH_IMAGE stage: linting script: - pip install flake8 yapf isort @@ -21,14 +24,22 @@ linting: - isort -rc --check-only --diff mmcls/ tools/ tests/ - yapf -r -d mmcls/ tools/ tests/ configs/ -test: +.test_template: &test_template_def stage: test script: - echo "Start building..." - - pip install pillow==6.2.2 + - pip install pillow==6.2.1 - pip install -e . - python -c "import mmcls; print(mmcls.__version__)" - echo "Start testing..." - pip install pytest coverage - coverage run --source mmcls -m pytest tests/ - coverage report -m + +test:pytorch1.3-cuda10: + image: $PYTORCH_IMAGE + <<: *test_template_def + +test:pat0.6.0dev-cuda9: + image: $PARROTS_IMAGE + <<: *test_template_def \ No newline at end of file diff --git a/mmcls/core/fp16/hooks.py b/mmcls/core/fp16/hooks.py index c3d4e098..565ac3a7 100644 --- a/mmcls/core/fp16/hooks.py +++ b/mmcls/core/fp16/hooks.py @@ -3,6 +3,7 @@ import copy import torch import torch.nn as nn from mmcv.runner import OptimizerHook +from mmcv.utils.parrots_wrapper import _BatchNorm from ..utils import allreduce_grads from .utils import cast_tensor_type @@ -95,7 +96,7 @@ def wrap_fp16_model(model): def patch_norm_fp32(module): - if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + if isinstance(module, (_BatchNorm, nn.GroupNorm)): module.float() module.forward = patch_forward_method(module.forward, torch.half, torch.float) diff --git a/mmcls/models/backbones/resnet.py b/mmcls/models/backbones/resnet.py index 6830379e..5e433da4 100644 --- a/mmcls/models/backbones/resnet.py +++ b/mmcls/models/backbones/resnet.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, kaiming_init) -from torch.nn.modules.batchnorm 
import _BatchNorm +from mmcv.utils.parrots_wrapper import _BatchNorm from ..builder import BACKBONES from .base_backbone import BaseBackbone diff --git a/requirements.txt b/requirements.txt index e4ff4f78..ed87c620 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,2 @@ -mmcv>=0.3.0 +mmcv-nightly numpy -torch>=1.1 -torchvision diff --git a/tests/test_backbones/test_resnet.py b/tests/test_backbones/test_resnet.py index b2d51001..b52692cc 100644 --- a/tests/test_backbones/test_resnet.py +++ b/tests/test_backbones/test_resnet.py @@ -1,7 +1,7 @@ import pytest import torch +from mmcv.utils.parrots_wrapper import _BatchNorm from torch.nn.modules import AvgPool2d -from torch.nn.modules.batchnorm import _BatchNorm from mmcls.models.backbones import ResNet, ResNetV1d from mmcls.models.backbones.resnet import BasicBlock, Bottleneck, ResLayer