From 16f1f77b41484c21e349910b81c31e391d099f41 Mon Sep 17 00:00:00 2001 From: Mike Date: Wed, 6 May 2020 23:21:50 -0400 Subject: [PATCH 01/13] Add a test workflow for github actions --- .github/workflows/tests.yml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .github/workflows/tests.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..2d75edff --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,31 @@ +name: Python tests + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + test: + name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }} + strategy: + matrix: + os: [ubuntu-latest, macOS-latest] + python: ['3.8', '3.7'] + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Run tests + run: | + pytest From 8da43e06171c5d796dc4b518412e6ea146dee421 Mon Sep 17 00:00:00 2001 From: michal Date: Thu, 7 May 2020 00:20:34 -0400 Subject: [PATCH 02/13] Install extra dependencies required by some models and log test durations --- .github/workflows/tests.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2d75edff..d035f3b2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,8 +24,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest + pip install pytest pytest-timeout if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install scipy + pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11 - name: Run tests run: | - pytest + pytest -vv --durations=0 ./tests From 69d725c9fe92efed62e954d31e95653fb091a797 Mon Sep 17 00:00:00 2001 From: michal Date: Thu, 7 May 2020 00:20:58 -0400 Subject: [PATCH 03/13] Basic forward pass test for all registered models --- tests/__init__.py | 0 tests/test_inference.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_inference.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 00000000..75b8d445 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,19 @@ +import pytest +import torch + +from timm import list_models, create_model + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + + inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' From 8c77f14cae7d974abb35a5c9f206274da49970ed Mon Sep 17 00:00:00 2001 From: michal Date: Thu, 7 May 2020 01:09:16 -0400 Subject: [PATCH 04/13] Install cpu version of torch on ubuntu --- .github/workflows/tests.yml | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d035f3b2..68fa4741 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,9 @@ jobs: strategy: matrix: os: [ubuntu-latest, macOS-latest] - python: ['3.8', '3.7'] + python: ['3.8'] + torch: ['1.5.0'] + torchvision: ['0.6.0'] runs-on: ${{ matrix.os }} steps: @@ -21,10 +23,18 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python }} - - name: Install dependencies + - name: Install testing dependencies run: | python -m pip install --upgrade pip pip install pytest pytest-timeout + - name: Install torch on mac + if: startsWith(matrix.os, 'macOS') + run: pip install torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }} + - name: Install torch on ubuntu + if: startsWith(matrix.os, 'ubuntu') + run: pip install torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install requirements + run: | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi pip install scipy pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11 From 305a2db70566bb991a1465c41f4c298a8bcb8a77 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 May 2020 10:14:24 -0700 Subject: [PATCH 05/13] Update test_inference.py Make the timeout 5-min for now, see if we can get a pass... --- tests/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 75b8d445..34bac63d 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -4,7 +4,7 @@ import torch from timm import list_models, create_model -@pytest.mark.timeout(60) +@pytest.mark.timeout(360) @pytest.mark.parametrize('model_name', list_models()) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): From e545bb9401517b0a3202f916315e40938324d093 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 May 2020 10:15:49 -0700 Subject: [PATCH 06/13] Update test_inference.py Not six min --- tests/test_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 34bac63d..dc45c409 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -4,7 +4,7 @@ import torch from timm import list_models, create_model -@pytest.mark.timeout(360) +@pytest.mark.timeout(300) @pytest.mark.parametrize('model_name', list_models()) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): From f4cdc2ac319a3d76db2986e7179d99715db14e8c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 11 May 2020 23:27:09 -0700 Subject: [PATCH 07/13] Add ResNeSt models --- README.md | 1 + timm/models/__init__.py | 1 + timm/models/layers/split_attn.py | 83 ++++++++++++ timm/models/resnest.py | 214 +++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+) create mode 100644 timm/models/layers/split_attn.py create mode 100644 timm/models/resnest.py diff --git a/README.md b/README.md index 70e2a701..ac6b57ce 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ Included models: * Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) * Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586) + * ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955) * DLA * Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index b073eb3a..d421ad45 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -18,6 +18,7 @@ from .dla import * from .hrnet import * from .sknet import * from .tresnet import * +from .resnest import * from .registry import * from .factory import create_model diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py new file mode 100644 index 00000000..f91892de --- /dev/null +++ b/timm/models/layers/split_attn.py @@ -0,0 +1,83 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttnConv2d(nn.Module): + """Split-Attention Conv2d + """ + def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, + act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + super(SplitAttnConv2d, self).__init__() + self.radix = radix + self.cardinality = groups + self.channels = channels + mid_chs = channels * radix + attn_chs = max(in_channels * radix // reduction_factor, 32) + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality) + self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality) + self.drop_block = drop_block + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + if self.bn0 is not None: + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = torch.sum(x, dim=1) + else: + x_gap = x + x_gap = F.adaptive_avg_pool2d(x_gap, 1) + x_gap = self.fc1(x_gap) + + if self.bn1 is not None: + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + + x_attn = self.fc2(x_gap) + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/timm/models/resnest.py b/timm/models/resnest.py new file mode 100644 index 00000000..e4f0157b --- /dev/null +++ b/timm/models/resnest.py @@ -0,0 +1,214 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +import math +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropBlock2d +from .helpers import load_pretrained +from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .layers.split_attn import SplitAttnConv2d +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + +default_cfgs = { + 'resnest26d': _cfg( + url=''), + 'resnest50d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'), + 'resnest101e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)), + 'resnest200e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)), + 'resnest269e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)), +} + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.drop_block1 = drop_block if drop_block is not None else None + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttnConv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, norm_layer=norm_layer, drop_block=drop_block) + self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness + self.drop_block2 = None + self.act2 = None + else: + self.conv2 = nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.drop_block2 = drop_block if drop_block is not None else None + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.drop_block3 = drop_block if drop_block is not None else None + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.drop_block1 is not None: + out = self.drop_block1(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + if self.bn2 is not None: + out = self.bn2(out) + if self.drop_block2 is not None: + out = self.drop_block2(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.drop_block3 is not None: + out = self.drop_block3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + return out + + +@register_model +def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-26d model. + """ + default_cfg = default_cfgs['resnest26d'] + model = ResNet( + ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. + """ + default_cfg = default_cfgs['resnest50d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest101e'] + model = ResNet( + ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest200e'] + model = ResNet( + ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest269e'] + model = ResNet( + ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model From 2f884a0ce5e7ff3708eb981159d1f54464f4c52d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 12 May 2020 12:21:52 -0700 Subject: [PATCH 08/13] Add resnest14, resnest26, and two of the abalation grouped resnest50 models --- timm/models/resnest.py | 46 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index e4f0157b..b68162d7 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -1,8 +1,8 @@ """ ResNeSt Models -Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 +Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 -Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang Modified for torchscript compat, and consistency with timm by Ross Wightman """ @@ -31,8 +31,10 @@ def _cfg(url='', **kwargs): } default_cfgs = { + 'resnest14d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'), 'resnest26d': _cfg( - url=''), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'), 'resnest50d': _cfg( url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'), 'resnest101e': _cfg( @@ -41,6 +43,12 @@ default_cfgs = { url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)), 'resnest269e': _cfg( url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)), + 'resnest50d_4s2x40d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth', + interpolation='bicubic'), + 'resnest50d_1s4x24d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_1s4x24d-d4a4f76f.pth', + interpolation='bicubic') } @@ -78,7 +86,7 @@ class ResNestBottleneck(nn.Module): if self.radix >= 1: self.conv2 = SplitAttnConv2d( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, groups=cardinality, norm_layer=norm_layer, drop_block=drop_block) + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness self.drop_block2 = None self.act2 = None @@ -135,9 +143,24 @@ class ResNestBottleneck(nn.Module): return out +@register_model +def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-14d model. Weights ported from GluonCV. + """ + default_cfg = default_cfgs['resnest14d'] + model = ResNet( + ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """ ResNeSt-26d model. + """ ResNeSt-26d model. Weights ported from GluonCV. """ default_cfg = default_cfgs['resnest26d'] model = ResNet( @@ -212,3 +235,16 @@ def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + + +@register_model +def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['resnest50d_1s4x24d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, + block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model From 9cc289f18c80100c8630808cf0842f4eb03f0b5d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 12 May 2020 13:07:03 -0700 Subject: [PATCH 09/13] Exclude EfficientNet-L2 models from test --- tests/test_inference.py | 2 +- timm/models/registry.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index dc45c409..55bafb21 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -5,7 +5,7 @@ from timm import list_models, create_model @pytest.mark.timeout(300) -@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*')) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/registry.py b/timm/models/registry.py index c15f5414..2b8a3717 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -42,12 +42,14 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False): +def list_models(filter='', module='', pretrained=False, exclude_filters=''): """ Return list of available model names, sorted alphabetically Args: filter (str) - Wildcard filter string that works with fnmatch module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' @@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False): else: models = _model_entrypoints.keys() if filter: - models = fnmatch.filter(models, filter) + models = fnmatch.filter(models, filter) # include these models + if exclude_filters: + if not isinstance(exclude_filters, list): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) if pretrained: models = _model_has_pretrained.intersection(models) return list(sorted(models, key=_natural_key)) From 5bd1ad13e714f3b4111e9356ab87f232f8fce863 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 12 May 2020 13:07:46 -0700 Subject: [PATCH 10/13] Refactor test indent --- tests/test_inference.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 55bafb21..2490a0bc 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -8,12 +8,12 @@ from timm import list_models, create_model @pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*')) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): - """Run a single forward pass with each model""" - model = create_model(model_name, pretrained=False) - model.eval() + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() - inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) - outputs = model(inputs) + inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) + outputs = model(inputs) - assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' From 208e7912f7441b8b2e21a8a8c295793e60cd6394 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 12 May 2020 13:36:31 -0700 Subject: [PATCH 11/13] Missed one of the abalation model entrypoints, update README --- README.md | 38 +++----------------------------- timm/models/layers/split_attn.py | 4 +--- timm/models/resnest.py | 17 ++++++++++++++ 3 files changed, 21 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index ac6b57ce..3cd5c223 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,9 @@ ## What's New +### May 12, 2020 +* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)) + ### May 3, 2020 * Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo) @@ -70,41 +73,6 @@ * Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section) * Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs. -### Dec 30, 2019 -* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch - -### Dec 28, 2019 -* Add new model weights and training hparams (see Training Hparams section) - * `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct - * trained with RandAugment, ended up with an interesting but less than perfect result (see training section) - * `seresnext26d_32x4d`- 77.6 top-1, 93.6 top-5 - * deep stem (32, 32, 64), avgpool downsample - * stem/dowsample from bag-of-tricks paper - * `seresnext26t_32x4d`- 78.0 top-1, 93.7 top-5 - * deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant) - * stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments - -### Dec 23, 2019 -* Add RandAugment trained MixNet-XL weights with 80.48 top-1. -* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval - -### Dec 4, 2019 -* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5). - -### Nov 29, 2019 -* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded. - * AdvProp weights added - * Official TF MobileNetv3 weights added -* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here... -* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification -* Consistency in global pooling, `reset_classifer`, and `forward_features` across models - * `forward_features` always returns unpooled feature maps now -* Reasonable chance I broke something... let me know - -### Nov 22, 2019 -* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update. -* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise. - ## Introduction For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and paired down iteration of that code. Hopefully it'll be of use to others. diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py index f91892de..383c4583 100644 --- a/timm/models/layers/split_attn.py +++ b/timm/models/layers/split_attn.py @@ -68,14 +68,12 @@ class SplitAttnConv2d(nn.Module): x_gap = x x_gap = F.adaptive_avg_pool2d(x_gap, 1) x_gap = self.fc1(x_gap) - if self.bn1 is not None: x_gap = self.bn1(x_gap) x_gap = self.act1(x_gap) - x_attn = self.fc2(x_gap) - x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) if self.radix > 1: out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) else: diff --git a/timm/models/resnest.py b/timm/models/resnest.py index b68162d7..849543ba 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -237,8 +237,25 @@ def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + default_cfg = default_cfgs['resnest50d_4s2x40d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, + block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ default_cfg = default_cfgs['resnest50d_1s4x24d'] model = ResNet( ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, From 17270c69b91c77981c921d3016aa0c2c4bfc6ec4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 12 May 2020 21:59:34 -0700 Subject: [PATCH 12/13] Remove annoying InceptionV3 dependency on scipy and insanely slow trunc_norm init. Bring InceptionV3 code into this codebase and use upcoming torch trunch_norm_ init. --- .github/workflows/tests.yml | 1 - timm/models/inception_v3.py | 606 ++++++++++++++++++++++++++---- timm/models/layers/__init__.py | 1 + timm/models/layers/weight_init.py | 60 +++ 4 files changed, 585 insertions(+), 83 deletions(-) create mode 100644 timm/models/layers/weight_init.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 68fa4741..ef43f0a2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,6 @@ jobs: - name: Install requirements run: | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - pip install scipy pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11 - name: Run tests run: | diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index a0ea784f..0997e024 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -1,120 +1,562 @@ -from torchvision.models import Inception3 +import torch +import torch.nn as nn +import torch.nn.functional as F + from .registry import register_model from .helpers import load_pretrained +from .layers import trunc_normal_, SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD __all__ = [] +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + default_cfgs = { # original PyTorch weights, ported from Tensorflow but modified - 'inception_v3': { - 'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', - 'input_size': (3, 299, 299), - 'crop_pct': 0.875, - 'interpolation': 'bicubic', - 'mean': IMAGENET_INCEPTION_MEAN, # also works well enough with resnet defaults - 'std': IMAGENET_INCEPTION_STD, # also works well enough with resnet defaults - 'num_classes': 1000, - 'first_conv': 'conv0', - 'classifier': 'fc' - }, + 'inception_v3': _cfg( + url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + has_aux=True), # checkpoint has aux logit layer weights # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) - 'tf_inception_v3': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', - 'input_size': (3, 299, 299), - 'crop_pct': 0.875, - 'interpolation': 'bicubic', - 'mean': IMAGENET_INCEPTION_MEAN, - 'std': IMAGENET_INCEPTION_STD, - 'num_classes': 1001, - 'first_conv': 'conv0', - 'classifier': 'fc' - }, + 'tf_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', + num_classes=1001, has_aux=False), # my port of Tensorflow adversarially trained Inception V3 from # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz - 'adv_inception_v3': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', - 'input_size': (3, 299, 299), - 'crop_pct': 0.875, - 'interpolation': 'bicubic', - 'mean': IMAGENET_INCEPTION_MEAN, - 'std': IMAGENET_INCEPTION_STD, - 'num_classes': 1001, - 'first_conv': 'conv0', - 'classifier': 'fc' - }, + 'adv_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', + num_classes=1001, has_aux=False), # from gluon pretrained models, best performing in terms of accuracy/loss metrics # https://gluon-cv.mxnet.io/model_zoo/classification.html - 'gluon_inception_v3': { - 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', - 'input_size': (3, 299, 299), - 'crop_pct': 0.875, - 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, # also works well with inception defaults - 'std': IMAGENET_DEFAULT_STD, # also works well with inception defaults - 'num_classes': 1000, - 'first_conv': 'conv0', - 'classifier': 'fc' - } + 'gluon_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', + mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults + std=IMAGENET_DEFAULT_STD, # also works well with inception defaults + has_aux=False, + ) } -def _assert_default_kwargs(kwargs): - # for imported models (ie torchvision) without capability to change these params, - # make sure they aren't being set to non-defaults - assert kwargs.pop('global_pool', 'avg') == 'avg' - assert kwargs.pop('drop_rate', 0.) == 0. +class InceptionV3Aux(nn.Module): + """InceptionV3 with AuxLogits + """ + + def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): + super(InceptionV3Aux, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + + if inception_blocks is None: + inception_blocks = [ + BasicConv2d, InceptionA, InceptionB, InceptionC, + InceptionD, InceptionE, InceptionAux + ] + assert len(inception_blocks) == 7 + conv_block = inception_blocks[0] + inception_a = inception_blocks[1] + inception_b = inception_blocks[2] + inception_c = inception_blocks[3] + inception_d = inception_blocks[4] + inception_e = inception_blocks[5] + inception_aux = inception_blocks[6] + + self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.Mixed_5b = inception_a(192, pool_features=32) + self.Mixed_5c = inception_a(256, pool_features=64) + self.Mixed_5d = inception_a(288, pool_features=64) + self.Mixed_6a = inception_b(288) + self.Mixed_6b = inception_c(768, channels_7x7=128) + self.Mixed_6c = inception_c(768, channels_7x7=160) + self.Mixed_6d = inception_c(768, channels_7x7=160) + self.Mixed_6e = inception_c(768, channels_7x7=192) + self.AuxLogits = inception_aux(768, num_classes) + self.Mixed_7a = inception_d(768) + self.Mixed_7b = inception_e(1280) + self.Mixed_7c = inception_e(2048) + + self.num_features = 2048 + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + aux = self.AuxLogits(x) if self.training else None + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + return x, aux + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + if self.num_classes > 0: + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.fc = nn.Identity() + + def forward(self, x): + x, aux = self.forward_features(x) + x = self.global_pool(x).flatten(1) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x, aux + + +class InceptionV3(nn.Module): + """Inception-V3 with no AuxLogits + FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns + """ + + def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): + super(InceptionV3, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + + if inception_blocks is None: + inception_blocks = [ + BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE] + assert len(inception_blocks) >= 6 + conv_block = inception_blocks[0] + inception_a = inception_blocks[1] + inception_b = inception_blocks[2] + inception_c = inception_blocks[3] + inception_d = inception_blocks[4] + inception_e = inception_blocks[5] + + self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.Mixed_5b = inception_a(192, pool_features=32) + self.Mixed_5c = inception_a(256, pool_features=64) + self.Mixed_5d = inception_a(288, pool_features=64) + self.Mixed_6a = inception_b(288) + self.Mixed_6b = inception_c(768, channels_7x7=128) + self.Mixed_6c = inception_c(768, channels_7x7=160) + self.Mixed_6d = inception_c(768, channels_7x7=160) + self.Mixed_6e = inception_c(768, channels_7x7=192) + self.Mixed_7a = inception_d(768) + self.Mixed_7b = inception_e(1280) + self.Mixed_7c = inception_e(2048) + + self.num_features = 2048 + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.fc = nn.Linear(2048, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + return x + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + if self.num_classes > 0: + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.fc = nn.Identity() + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x).flatten(1) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = nn.Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +def _inception_v3(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + if kwargs.pop('features_only', False): + assert False, 'Not Implemented' # TODO + load_strict = False + model_kwargs.pop('num_classes', 0) + model_class = InceptionV3 + else: + aux_logits = kwargs.pop('aux_logits', False) + if aux_logits: + model_class = InceptionV3Aux + load_strict = default_cfg['has_aux'] + else: + model_class = InceptionV3 + load_strict = not default_cfg['has_aux'] + + model = model_class(**kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3), + strict=load_strict) + return model @register_model -def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def inception_v3(pretrained=False, **kwargs): # original PyTorch weights, ported from Tensorflow but modified - default_cfg = default_cfgs['inception_v3'] - assert in_chans == 3 - _assert_default_kwargs(kwargs) - model = Inception3(num_classes=num_classes, aux_logits=True, transform_input=False) - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - model.default_cfg = default_cfg + model = _inception_v3('inception_v3', pretrained=pretrained, **kwargs) return model @register_model -def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def tf_inception_v3(pretrained=False, **kwargs): # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) - default_cfg = default_cfgs['tf_inception_v3'] - assert in_chans == 3 - _assert_default_kwargs(kwargs) - model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False) - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - model.default_cfg = default_cfg + model = _inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) return model @register_model -def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def adv_inception_v3(pretrained=False, **kwargs): # my port of Tensorflow adversarially trained Inception V3 from # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz - default_cfg = default_cfgs['adv_inception_v3'] - assert in_chans == 3 - _assert_default_kwargs(kwargs) - model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False) - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - model.default_cfg = default_cfg + model = _inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs) return model @register_model -def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def gluon_inception_v3(pretrained=False, **kwargs): # from gluon pretrained models, best performing in terms of accuracy/loss metrics # https://gluon-cv.mxnet.io/model_zoo/classification.html - default_cfg = default_cfgs['gluon_inception_v3'] - assert in_chans == 3 - _assert_default_kwargs(kwargs) - model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False) - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - model.default_cfg = default_cfg + model = _inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs) return model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 4f84bb9e..667e7ea1 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -19,3 +19,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .anti_aliasing import AntiAliasDownsampleLayer from .space_to_depth import SpaceToDepthModule from .blur_pool import BlurPool2d +from .weight_init import trunc_normal_ diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py new file mode 100644 index 00000000..d731029f --- /dev/null +++ b/timm/models/layers/weight_init.py @@ -0,0 +1,60 @@ +import torch +import math +import warnings + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) From 1904ed8fecdb3f37818378421350315d2abf1224 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 13 May 2020 15:17:08 -0700 Subject: [PATCH 13/13] Improve dropblock impl, add fast variant, and better AMP speed, inplace, batchwise... few ResNeSt cleanups --- timm/models/layers/drop.py | 108 ++++++++++++++++++++++++------- timm/models/layers/split_attn.py | 15 ++--- timm/models/resnest.py | 17 ++--- 3 files changed, 97 insertions(+), 43 deletions(-) diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 00bed078..5f2008c0 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -22,44 +22,89 @@ import math def drop_block_2d( - x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7, - gamma_scale: float = 1.0, drop_with_noise: bool = False): + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf DropBlock with an experimental gaussian noise option. This layer has been tested on a few training runs with success, but needs further validation and possibly optimization for lower runtime impact. - """ - if drop_prob == 0. or not training: - return x - _, _, height, width = x.shape - total_size = width * height - clipped_block_size = min(block_size, min(width, height)) + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) # seed_drop_rate, the gamma parameter - seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (width - block_size + 1) * - (height - block_size + 1)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) # Forces the block to be inside the feature map. - w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device)) - valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ - ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) - valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() + w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) - uniform_noise = torch.rand_like(x, dtype=torch.float32) - block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) block_mask = -F.max_pool2d( -block_mask, - kernel_size=clipped_block_size, # block_size, ??? + kernel_size=clipped_block_size, # block_size, stride=1, padding=clipped_block_size // 2) - if drop_with_noise: - normal_noise = torch.randn_like(x) - x = x * block_mask + normal_noise * (1 - block_mask) + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) else: - normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7) - x = x * block_mask * normalize_scale + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + if batchwise: + # one mask for whole batch, quite a bit faster + block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma + else: + # mask per batch element + block_mask = torch.rand_like(x) < gamma + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale return x @@ -70,15 +115,28 @@ class DropBlock2d(nn.Module): drop_prob=0.1, block_size=7, gamma_scale=1.0, - with_noise=False): + with_noise=False, + inplace=False, + batchwise=False, + fast=True): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale self.block_size = block_size self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not def forward(self, x): - return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise) + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) def drop_path(x, drop_prob: float = 0., training: bool = False): diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py index 383c4583..023ab6af 100644 --- a/timm/models/layers/split_attn.py +++ b/timm/models/layers/split_attn.py @@ -31,25 +31,24 @@ class RadixSoftmax(nn.Module): class SplitAttnConv2d(nn.Module): """Split-Attention Conv2d """ - def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0, + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): super(SplitAttnConv2d, self).__init__() self.radix = radix - self.cardinality = groups - self.channels = channels - mid_chs = channels * radix + self.drop_block = drop_block + mid_chs = out_channels * radix attn_chs = max(in_channels * radix // reduction_factor, 32) + self.conv = nn.Conv2d( in_channels, mid_chs, kernel_size, stride, padding, dilation, groups=groups * radix, bias=bias, **kwargs) self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None self.act0 = act_layer(inplace=True) - self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None self.act1 = act_layer(inplace=True) - self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality) - self.drop_block = drop_block + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) self.rsoftmax = RadixSoftmax(radix, groups) def forward(self, x): @@ -63,7 +62,7 @@ class SplitAttnConv2d(nn.Module): B, RC, H, W = x.shape if self.radix > 1: x = x.reshape((B, self.radix, RC // self.radix, H, W)) - x_gap = torch.sum(x, dim=1) + x_gap = x.sum(dim=1) else: x_gap = x x_gap = F.adaptive_avg_pool2d(x_gap, 1) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 849543ba..33b051ef 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -76,10 +76,10 @@ class ResNestBottleneck(nn.Module): else: avd_stride = 0 self.radix = radix + self.drop_block = drop_block self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) self.bn1 = norm_layer(group_width) - self.drop_block1 = drop_block if drop_block is not None else None self.act1 = act_layer(inplace=True) self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None @@ -88,20 +88,17 @@ class ResNestBottleneck(nn.Module): group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness - self.drop_block2 = None self.act2 = None else: self.conv2 = nn.Conv2d( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(group_width) - self.drop_block2 = drop_block if drop_block is not None else None self.act2 = act_layer(inplace=True) self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) self.bn3 = norm_layer(planes*4) - self.drop_block3 = drop_block if drop_block is not None else None self.act3 = act_layer(inplace=True) self.downsample = downsample @@ -113,8 +110,8 @@ class ResNestBottleneck(nn.Module): out = self.conv1(x) out = self.bn1(out) - if self.drop_block1 is not None: - out = self.drop_block1(out) + if self.drop_block is not None: + out = self.drop_block(out) out = self.act1(out) if self.avd_first is not None: @@ -123,8 +120,8 @@ class ResNestBottleneck(nn.Module): out = self.conv2(out) if self.bn2 is not None: out = self.bn2(out) - if self.drop_block2 is not None: - out = self.drop_block2(out) + if self.drop_block is not None: + out = self.drop_block(out) out = self.act2(out) if self.avd_last is not None: @@ -132,8 +129,8 @@ class ResNestBottleneck(nn.Module): out = self.conv3(out) out = self.bn3(out) - if self.drop_block3 is not None: - out = self.drop_block3(out) + if self.drop_block is not None: + out = self.drop_block(out) if self.downsample is not None: residual = self.downsample(x)