Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Merge pull request #668 from rwightman/more_attn
Add Gather-Excite, Global Context, BAT, and Non-Local attention modules; refactor all attention modules and the attention factory for improved consistency. EfficientNet / MobileNetV3 backbones can now use a wider variety of attention modules.
Commit: 54a6cca27a

.github/workflows/tests.yml (vendored)
@ -36,11 +36,16 @@ jobs:
        run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
      - name: Install torch on ubuntu
        if: startsWith(matrix.os, 'ubuntu')
        run: pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
        run: |
          pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
          sudo apt update
          sudo apt install -y google-perftools
      - name: Install requirements
        run: |
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.0.12
      - name: Run tests
        env:
          LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
        run: |
          pytest -vv --durations=0 ./tests
README.md

@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included.
* SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
* DropBlock (https://arxiv.org/abs/1810.12890)
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
* An extensive selection of channel and/or spatial attention modules (see the factory sketch after this list):
  * Bottleneck Transformer - https://arxiv.org/abs/2101.11605
  * CBAM - https://arxiv.org/abs/1807.06521
  * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
  * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
  * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
  * Global Context (GC) - https://arxiv.org/abs/1904.11492
  * Halo - https://arxiv.org/abs/2103.12731
  * Involution - https://arxiv.org/abs/2103.06255
  * Lambda Layer - https://arxiv.org/abs/2102.08602
  * Non-Local (NL) - https://arxiv.org/abs/1711.07971
  * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
  * Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
  * Split (SPLAT) - https://arxiv.org/abs/2004.08955
  * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
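All of the modules above are exposed through a single attention factory in `timm.models.layers` (the factory refactored by this PR). A minimal sketch of dropping one into your own block, assuming timm at roughly this revision is installed:

```python
import torch
import torch.nn as nn
from timm.models.layers import get_attn  # unified attention factory

# 'ge' (Gather-Excite) here; 'gc', 'eca', 'cbam', 'nl', 'se', ... resolve the same way
attn = get_attn('ge')(64)
block = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), attn)
print(block(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```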
## Results
@ -24,7 +24,7 @@ NUM_NON_STD = len(NON_STD_FILTERS)

if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
    # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
    EXCLUDE_FILTERS = [
        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
        '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
        '*resnetrs350*', '*resnetrs420*']
else:
@ -17,7 +17,6 @@ from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levit import *
#from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .nasnet import *
@ -12,24 +12,12 @@ Consider all of the models definitions here as experimental WIP and likely to ch

Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\
    reduce_feat_size, register_block, num_groups, LayerFn, _init_weights
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\
    make_divisible, to_2tuple
from .registry import register_model

__all__ = ['ByoaNet']
__all__ = []


def _cfg(url='', **kwargs):
@ -63,100 +51,68 @@ default_cfgs = {
|
||||
'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
|
||||
'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ByoaBlocksCfg(BlocksCfg):
|
||||
# FIXME allow overriding self_attn layer or args per block/stage,
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ByoaCfg(ByobCfg):
|
||||
blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None
|
||||
self_attn_layer: Optional[str] = None
|
||||
self_attn_fixed_size: bool = False
|
||||
self_attn_kwargs: dict = field(default_factory=lambda: dict())
|
||||
|
||||
|
||||
def interleave_attn(
        types : Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoaBlocksCfg]:
    """ interleave attn blocks
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every))
        if not every:
            every = [d - 1]
    set(every)
    blocks = []
    for i in range(d):
        block_type = types[1] if i in every else types[0]
        blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)]
    return tuple(blocks)

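The interleaving rule above (kept after this refactor as `interleave_blocks` in byobnet.py) is simple enough to check standalone. A pure-Python sketch of which positions in a stage of depth `d` become self-attention blocks:

```python
# Standalone illustration of the interleave logic above: types[1] is used at the
# 'every'-th positions, types[0] elsewhere; falls back to the last block if none match.
def interleave(types, every, d, first=False):
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every)) or [d - 1]
    return tuple(types[1] if i in every else types[0] for i in range(d))

print(interleave(('bottle', 'self_attn'), every=1, d=2))  # ('bottle', 'self_attn')
print(interleave(('bottle', 'self_attn'), every=3, d=6))  # 'self_attn' only at index 3
```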
model_cfgs = dict(
|
||||
|
||||
botnet26t=ByoaCfg(
|
||||
botnet26t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
self_attn_layer='bottleneck',
|
||||
self_attn_fixed_size=True,
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
botnet50ts=ByoaCfg(
|
||||
botnet50ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
act_layer='silu',
|
||||
self_attn_layer='bottleneck',
|
||||
self_attn_fixed_size=True,
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
eca_botnext26ts=ByoaCfg(
|
||||
eca_botnext26ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
act_layer='silu',
|
||||
attn_layer='eca',
|
||||
self_attn_layer='bottleneck',
|
||||
self_attn_fixed_size=True,
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
|
||||
halonet_h1=ByoaCfg(
|
||||
halonet_h1=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
|
||||
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='7x7',
|
||||
@ -165,12 +121,12 @@ model_cfgs = dict(
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=3),
|
||||
),
|
||||
halonet_h1_c4c5=ByoaCfg(
|
||||
halonet_h1_c4c5=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
|
||||
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -179,12 +135,12 @@ model_cfgs = dict(
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=3),
|
||||
),
|
||||
halonet26t=ByoaCfg(
|
||||
halonet26t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -193,12 +149,12 @@ model_cfgs = dict(
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
|
||||
),
|
||||
halonet50ts=ByoaCfg(
|
||||
halonet50ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -208,12 +164,12 @@ model_cfgs = dict(
|
||||
self_attn_layer='halo',
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2)
|
||||
),
|
||||
eca_halonext26ts=ByoaCfg(
|
||||
eca_halonext26ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -225,12 +181,12 @@ model_cfgs = dict(
|
||||
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
|
||||
),
|
||||
|
||||
lambda_resnet26t=ByoaCfg(
|
||||
lambda_resnet26t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -239,12 +195,12 @@ model_cfgs = dict(
|
||||
self_attn_layer='lambda',
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
lambda_resnet50t=ByoaCfg(
|
||||
lambda_resnet50t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -253,12 +209,12 @@ model_cfgs = dict(
|
||||
self_attn_layer='lambda',
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
eca_lambda_resnext26ts=ByoaCfg(
|
||||
eca_lambda_resnext26ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -270,77 +226,76 @@ model_cfgs = dict(
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
|
||||
swinnet26t=ByoaCfg(
|
||||
swinnet26t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
self_attn_layer='swin',
|
||||
self_attn_fixed_size=True,
|
||||
self_attn_kwargs=dict(win_size=8)
|
||||
),
|
||||
swinnet50ts=ByoaCfg(
|
||||
swinnet50ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
act_layer='silu',
|
||||
self_attn_layer='swin',
|
||||
self_attn_fixed_size=True,
|
||||
self_attn_kwargs=dict(win_size=8)
|
||||
),
|
||||
eca_swinnext26ts=ByoaCfg(
|
||||
eca_swinnext26ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
|
||||
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
fixed_input_size=True,
|
||||
act_layer='silu',
|
||||
attn_layer='eca',
|
||||
self_attn_layer='swin',
|
||||
self_attn_fixed_size=True,
|
||||
self_attn_kwargs=dict(win_size=8)
|
||||
),
|
||||
|
||||
|
||||
rednet26t=ByoaCfg(
|
||||
rednet26t=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered', # FIXME RedNet uses involution in middle of stem
|
||||
stem_pool='maxpool',
|
||||
num_features=0,
|
||||
self_attn_layer='involution',
|
||||
self_attn_fixed_size=False,
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
rednet50ts=ByoaCfg(
|
||||
rednet50ts=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
|
||||
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
|
||||
),
|
||||
stem_chs=64,
|
||||
stem_type='tiered',
|
||||
@ -348,161 +303,14 @@ model_cfgs = dict(
|
||||
num_features=0,
|
||||
act_layer='silu',
|
||||
self_attn_layer='involution',
|
||||
self_attn_fixed_size=False,
|
||||
self_attn_kwargs=dict()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ByoaLayerFn(LayerFn):
|
||||
self_attn: Optional[Callable] = None
|
||||
|
||||
|
||||
class SelfAttnBlock(nn.Module):
|
||||
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
|
||||
"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
|
||||
downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
|
||||
layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.):
|
||||
super(SelfAttnBlock, self).__init__()
|
||||
assert layers is not None
|
||||
mid_chs = make_divisible(out_chs * bottle_ratio)
|
||||
groups = num_groups(group_size, mid_chs)
|
||||
|
||||
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
||||
self.shortcut = create_downsample(
|
||||
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
|
||||
apply_act=False, layers=layers)
|
||||
else:
|
||||
self.shortcut = nn.Identity()
|
||||
|
||||
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
|
||||
if extra_conv:
|
||||
self.conv2_kxk = layers.conv_norm_act(
|
||||
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
|
||||
groups=groups, drop_block=drop_block)
|
||||
stride = 1 # striding done via conv if enabled
|
||||
else:
|
||||
self.conv2_kxk = nn.Identity()
|
||||
opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
|
||||
# FIXME need to dilate self attn to have dilated network support, moop moop
|
||||
self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
|
||||
self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
|
||||
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
||||
|
||||
def init_weights(self, zero_init_last_bn=False):
|
||||
if zero_init_last_bn:
|
||||
nn.init.zeros_(self.conv3_1x1.bn.weight)
|
||||
if hasattr(self.self_attn, 'reset_parameters'):
|
||||
self.self_attn.reset_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut(x)
|
||||
|
||||
x = self.conv1_1x1(x)
|
||||
x = self.conv2_kxk(x)
|
||||
x = self.self_attn(x)
|
||||
x = self.post_attn(x)
|
||||
x = self.conv3_1x1(x)
|
||||
x = self.drop_path(x)
|
||||
|
||||
x = self.act(x + shortcut)
|
||||
return x
|
||||
|
||||
register_block('self_attn', SelfAttnBlock)
|
||||
|
||||
|
||||
def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None):
|
||||
if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size:
|
||||
assert feat_size is not None
|
||||
block_kwargs['feat_size'] = feat_size
|
||||
return block_kwargs
|
||||
|
||||
|
||||
def get_layer_fns(cfg: ByoaCfg):
|
||||
act = get_act_layer(cfg.act_layer)
|
||||
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
|
||||
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
|
||||
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||
self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
|
||||
layer_fn = ByoaLayerFn(
|
||||
conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
|
||||
return layer_fn
|
||||
|
||||
|
||||
class ByoaNet(nn.Module):
|
||||
""" 'Bring-your-own-attention' Net
|
||||
|
||||
A ResNet inspired backbone that supports interleaving traditional residual blocks with
|
||||
'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention
|
||||
or similar module.
|
||||
|
||||
FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but
|
||||
torchscript limitations prevent sensible inheritance overrides.
|
||||
"""
|
||||
def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
|
||||
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
layers = get_layer_fns(cfg)
|
||||
feat_size = to_2tuple(img_size) if img_size is not None else None
|
||||
|
||||
self.feature_info = []
|
||||
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
|
||||
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
|
||||
self.feature_info.extend(stem_feat[:-1])
|
||||
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
|
||||
|
||||
self.stages, stage_feat = create_byob_stages(
|
||||
cfg, drop_path_rate, output_stride, stem_feat[-1],
|
||||
feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args)
|
||||
self.feature_info.extend(stage_feat[:-1])
|
||||
|
||||
prev_chs = stage_feat[-1]['num_chs']
|
||||
if cfg.num_features:
|
||||
self.num_features = int(round(cfg.width_factor * cfg.num_features))
|
||||
self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
|
||||
else:
|
||||
self.num_features = prev_chs
|
||||
self.final_conv = nn.Identity()
|
||||
self.feature_info += [
|
||||
dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
|
||||
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_init_weights(m, n)
|
||||
for m in self.modules():
|
||||
# call each block's weight init for block-specific overrides to init above
|
||||
if hasattr(m, 'init_weights'):
|
||||
m.init_weights(zero_init_last_bn=zero_init_last_bn)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
x = self.final_conv(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
|
||||
return build_model_with_cfg(
|
||||
ByoaNet, variant, pretrained,
|
||||
ByobNet, variant, pretrained,
|
||||
default_cfg=default_cfgs[variant],
|
||||
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
|
||||
feature_cfg=dict(flatten_sequential=True),
|
||||
|
File diff suppressed because it is too large
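The net effect of the byoanet.py changes above is that these attention variants are now built as plain `ByobNet` instances from `ByoModelCfg`. A rough sketch, assuming this revision of timm; 'halonet26t' is one of the configs defined above, and most of these experimental variants had no pretrained weights at this point:

```python
import torch
import timm

# 'halonet26t' comes from the model_cfgs above; pretrained weights were not yet available
model = timm.create_model('halonet26t', pretrained=False)
print(type(model).__name__)                       # ByobNet after this commit
print(model(torch.randn(1, 3, 256, 256)).shape)   # torch.Size([1, 1000])
```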
@ -91,6 +91,12 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
|
||||
interpolation='bilinear'),
|
||||
|
||||
# NOTE experimenting with alternate attention
|
||||
'eca_efficientnet_b0': _cfg(
|
||||
url=''),
|
||||
'gc_efficientnet_b0': _cfg(
|
||||
url=''),
|
||||
|
||||
'efficientnet_b0': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
|
||||
'efficientnet_b1': _cfg(
|
||||
@ -1223,6 +1229,26 @@ def efficientnet_b0(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def eca_efficientnet_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B0 w/ ECA attn """
|
||||
# NOTE experimental config
|
||||
model = _gen_efficientnet(
|
||||
'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0,
|
||||
pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gc_efficientnet_b0(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B0 w/ GlobalContext """
|
||||
# NOTE experimental config
|
||||
model = _gen_efficientnet(
|
||||
'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0,
|
||||
pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
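Both variants above are registered entrypoints, so they can be built by name. A brief sketch (experimental configs, no pretrained weights at this point, assuming this revision of timm):

```python
import torch
import timm

# experimental attention variants registered above
for name in ('eca_efficientnet_b0', 'gc_efficientnet_b0'):
    m = timm.create_model(name, pretrained=False)
    print(name, m(torch.randn(1, 3, 224, 224)).shape)  # ... torch.Size([1, 1000])
```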
@register_model
|
||||
def efficientnet_b1(pretrained=False, **kwargs):
|
||||
""" EfficientNet-B1 """
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
|
||||
from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
|
||||
from .layers.activations import sigmoid
|
||||
|
||||
__all__ = [
|
||||
@ -19,33 +19,32 @@ class SqueezeExcite(nn.Module):
|
||||
|
||||
Args:
|
||||
in_chs (int): input channels to layer
|
||||
se_ratio (float): ratio of squeeze reduction
|
||||
rd_ratio (float): ratio of squeeze reduction
|
||||
act_layer (nn.Module): activation layer of containing block
|
||||
gate_fn (Callable): attention gate function
|
||||
block_in_chs (int): input channels of containing block (for calculating reduction from)
|
||||
reduce_from_block (bool): calculate reduction from block input channels if True
|
||||
gate_layer (Callable): attention gate function
|
||||
force_act_layer (nn.Module): override block's activation fn if this is set/bound
|
||||
divisor (int): make reduction channels divisible by this
|
||||
rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
|
||||
"""
|
||||
|
||||
    def __init__(
            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
            block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
        super(SqueezeExcite, self).__init__()
        reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
        reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
        self.gate_fn = get_act_fn(gate_fn)
        self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate_fn(x_se)
        return x * self.gate(x_se)

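A short sketch of the old-to-new argument mapping implied above, with hypothetical numbers: what the former `se_ratio` / `block_in_chs` pair expressed is now passed as `rd_ratio` or an explicit `rd_channels`.

```python
import torch
from timm.models.efficientnet_blocks import SqueezeExcite

# old: SqueezeExcite(mid_chs, se_ratio=0.25, block_in_chs=in_chs)
# new: reduction channels come from rd_ratio (of in_chs) or an explicit rd_channels
se = SqueezeExcite(96, rd_channels=4)   # 4 == round(16 * 0.25), i.e. reduction from the block input
x = torch.randn(2, 96, 14, 14)
print(se(x).shape)                      # torch.Size([2, 96, 14, 14])
```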
class ConvBnAct(nn.Module):
|
||||
@ -87,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
|
||||
"""
|
||||
def __init__(
|
||||
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
|
||||
noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
|
||||
noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||
se_layer=None, drop_path_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
has_se = se_layer is not None and se_ratio > 0.
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.has_pw_act = pw_act # activation after point-wise conv
|
||||
self.drop_path_rate = drop_path_rate
|
||||
@ -101,7 +99,7 @@ class DepthwiseSeparableConv(nn.Module):
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
|
||||
self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
|
||||
|
||||
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
|
||||
self.bn2 = norm_layer(out_chs)
|
||||
@ -146,12 +144,11 @@ class InvertedResidual(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
|
||||
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
|
||||
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
mid_chs = make_divisible(in_chs * exp_ratio)
|
||||
has_se = se_layer is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_path_rate = drop_path_rate
|
||||
|
||||
@ -168,8 +165,7 @@ class InvertedResidual(nn.Module):
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
self.se = se_layer(
|
||||
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
|
||||
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
@ -215,8 +211,8 @@ class CondConvResidual(InvertedResidual):
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
|
||||
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
|
||||
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
|
||||
|
||||
self.num_experts = num_experts
|
||||
conv_kwargs = dict(num_experts=self.num_experts)
|
||||
@ -224,8 +220,8 @@ class CondConvResidual(InvertedResidual):
|
||||
super(CondConvResidual, self).__init__(
|
||||
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
|
||||
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
|
||||
norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
|
||||
pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
|
||||
drop_path_rate=drop_path_rate)
|
||||
|
||||
self.routing_fn = nn.Linear(in_chs, self.num_experts)
|
||||
|
||||
@ -274,8 +270,8 @@ class EdgeResidual(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
|
||||
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
|
||||
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
|
||||
super(EdgeResidual, self).__init__()
|
||||
if force_in_chs > 0:
|
||||
mid_chs = make_divisible(force_in_chs * exp_ratio)
|
||||
@ -292,8 +288,7 @@ class EdgeResidual(nn.Module):
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
self.se = SqueezeExcite(
|
||||
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
|
||||
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
|
||||
|
@ -10,11 +10,12 @@ import logging
|
||||
import math
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .efficientnet_blocks import *
|
||||
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
|
||||
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
|
||||
|
||||
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
|
||||
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
||||
@ -120,7 +121,9 @@ def _decode_block_str(block_str):
|
||||
elif v == 'hs':
|
||||
value = get_act_layer('hard_swish')
|
||||
elif v == 'sw':
|
||||
value = get_act_layer('swish')
|
||||
value = get_act_layer('swish') # aka SiLU
|
||||
elif v == 'mi':
|
||||
value = get_act_layer('mish')
|
||||
else:
|
||||
continue
|
||||
options[key] = value
|
||||
@ -265,14 +268,20 @@ class EfficientNetBuilder:
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """
    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.round_chs_fn = round_chs_fn
        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = se_layer
        self.se_layer = get_attn(se_layer)
        try:
            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
            self.se_has_ratio = True
        except TypeError:
            self.se_has_ratio = False
        self.drop_path_rate = drop_path_rate
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@ -299,16 +308,21 @@ class EfficientNetBuilder:
|
||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
|
||||
assert ba['act_layer'] is not None
|
||||
ba['norm_layer'] = self.norm_layer
|
||||
ba['drop_path_rate'] = drop_path_rate
|
||||
if bt != 'cn':
|
||||
ba['se_layer'] = self.se_layer
|
||||
ba['drop_path_rate'] = drop_path_rate
|
||||
se_ratio = ba.pop('se_ratio')
|
||||
if se_ratio and self.se_layer is not None:
|
||||
if not self.se_from_exp:
|
||||
# adjust se_ratio by expansion ratio if calculating se channels from block input
|
||||
se_ratio /= ba.get('exp_ratio', 1.0)
|
||||
if self.se_has_ratio:
|
||||
ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
|
||||
else:
|
||||
ba['se_layer'] = self.se_layer
|
||||
|
||||
if bt == 'ir':
|
||||
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
|
||||
if ba.get('num_experts', 0) > 0:
|
||||
block = CondConvResidual(**ba)
|
||||
else:
|
||||
block = InvertedResidual(**ba)
|
||||
block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
|
||||
elif bt == 'ds' or bt == 'dsa':
|
||||
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
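A worked example of the `se_from_exp` adjustment shown above, with hypothetical block numbers: when reduction is meant to be computed from the block input channels rather than the expanded (mid) channels, the ratio is divided by the expansion ratio before being bound as `rd_ratio`.

```python
# hypothetical inverted-residual block numbers
in_chs, exp_ratio, se_ratio = 16, 6.0, 0.25
mid_chs = int(in_chs * exp_ratio)    # 96 expanded channels, what the SE layer actually sees
rd_ratio = se_ratio / exp_ratio      # 0.25 / 6 ~= 0.0417, bound via partial(se_layer, rd_ratio=...)
print(round(mid_chs * rd_ratio))     # 4 == round(in_chs * se_ratio), i.e. reduction from block input
```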
@ -418,28 +432,28 @@ def _init_weight_goog(m, n='', fix_group_fanout=True):
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
init_weight_fn = get_condconv_initializer(
|
||||
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
||||
lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
||||
init_weight_fn(m.weight)
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
if fix_group_fanout:
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
fan_out = m.weight.size(0) # fan-out
|
||||
fan_in = 0
|
||||
if 'routing_fn' in n:
|
||||
fan_in = m.weight.size(1)
|
||||
init_range = 1.0 / math.sqrt(fan_in + fan_out)
|
||||
m.weight.data.uniform_(-init_range, init_range)
|
||||
m.bias.data.zero_()
|
||||
nn.init.uniform_(m.weight, -init_range, init_range)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
def efficientnet_init_weights(model: nn.Module, init_fn=None):
|
||||
|
@ -40,7 +40,7 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4)
|
||||
_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
|
||||
|
||||
|
||||
class GhostModule(nn.Module):
|
||||
@ -92,7 +92,7 @@ class GhostBottleneck(nn.Module):
|
||||
self.bn_dw = None
|
||||
|
||||
# Squeeze-and-excitation
|
||||
self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None
|
||||
self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
|
||||
|
||||
# Point-wise linear projection
|
||||
self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)
|
||||
|
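The `_SE_LAYER` partial above is the general migration pattern for callers of the refactored `SqueezeExcite`: each model binds its preferred gate and channel rounding up front. A minimal sketch with illustrative values:

```python
from functools import partial
from timm.models.efficientnet_blocks import SqueezeExcite
from timm.models.layers import make_divisible

# bind GhostNet's preferred gate and channel rounding, as _SE_LAYER does above
GhostSE = partial(SqueezeExcite, gate_layer='hard_sigmoid',
                  rd_round_fn=partial(make_divisible, divisor=4))
se = GhostSE(96, rd_ratio=0.25)   # reduced channels = make_divisible(96 * 0.25, divisor=4) = 24
```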
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .efficientnet_blocks import SqueezeExcite
|
||||
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
|
||||
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
|
||||
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||
from .layers import get_act_fn
|
||||
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
|
||||
@ -39,8 +39,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
|
||||
|
||||
"""
|
||||
num_features = 1280
|
||||
se_layer = partial(
|
||||
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
|
||||
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
|
@ -12,26 +12,28 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||
from .create_attn import get_attn, create_attn
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
|
||||
from .create_self_attn import get_self_attn, create_self_attn
|
||||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||
from .eca import EcaModule, CecaModule
|
||||
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
from .gather_excite import GatherExcite
|
||||
from .global_context import GlobalContext
|
||||
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
|
||||
from .inplace_abn import InplaceAbn
|
||||
from .involution import Involution
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .mlp import Mlp, GluMlp, GatedMlp
|
||||
from .norm import GroupNorm
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, LayerNorm2d
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct
|
||||
from .padding import get_padding, get_same_padding, pad_same
|
||||
from .patch_embed import PatchEmbed
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
from .se import SEModule
|
||||
from .selective_kernel import SelectiveKernelConv
|
||||
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
||||
from .selective_kernel import SelectiveKernel
|
||||
from .separable_conv import SeparableConv2d, SeparableConvBnAct
|
||||
from .space_to_depth import SpaceToDepthModule
|
||||
from .split_attn import SplitAttnConv2d
|
||||
from .split_attn import SplitAttn
|
||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
|
@ -7,78 +7,87 @@ some tasks, especially fine-grained it seems. I may end up removing this impl.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .create_act import create_act_layer, get_act_layer
|
||||
from .helpers import make_divisible
|
||||
|
||||
|
||||
class ChannelAttn(nn.Module):
|
||||
""" Original CBAM channel attention module, currently avg + max pool variant only.
|
||||
"""
|
||||
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
|
||||
def __init__(
|
||||
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
|
||||
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
|
||||
super(ChannelAttn, self).__init__()
|
||||
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
|
||||
if not rd_channels:
|
||||
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
|
||||
self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
|
||||
self.act = act_layer(inplace=True)
|
||||
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
|
||||
self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
|
||||
self.gate = create_act_layer(gate_layer)
|
||||
|
||||
def forward(self, x):
|
||||
x_avg = x.mean((2, 3), keepdim=True)
|
||||
x_max = F.adaptive_max_pool2d(x, 1)
|
||||
x_avg = self.fc2(self.act(self.fc1(x_avg)))
|
||||
x_max = self.fc2(self.act(self.fc1(x_max)))
|
||||
x_attn = x_avg + x_max
|
||||
return x * x_attn.sigmoid()
|
||||
x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
|
||||
x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
|
||||
return x * self.gate(x_avg + x_max)
|
||||
|
||||
|
||||
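The reduction sizing above follows the same `rd_*` convention as the other refactored modules. With the default `rd_ratio=1/16` and `rd_divisor=1`, a hypothetical 64-channel input squeezes to 4 channels:

```python
from timm.models.layers import make_divisible

# default CBAM channel-attention bottleneck for a 64-channel feature map
rd_channels = make_divisible(64 * 1. / 16, 1, round_limit=0.)
print(rd_channels)   # 4 -> fc1: 64 -> 4, fc2: 4 -> 64
```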
class LightChannelAttn(ChannelAttn):
|
||||
"""An experimental 'lightweight' that sums avg + max pool first
|
||||
"""
|
||||
def __init__(self, channels, reduction=16):
|
||||
super(LightChannelAttn, self).__init__(channels, reduction)
|
||||
def __init__(
|
||||
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
|
||||
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
|
||||
super(LightChannelAttn, self).__init__(
|
||||
channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
|
||||
|
||||
def forward(self, x):
|
||||
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
|
||||
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
|
||||
x_attn = self.fc2(self.act(self.fc1(x_pool)))
|
||||
return x * x_attn.sigmoid()
|
||||
return x * F.sigmoid(x_attn)
|
||||
|
||||
|
||||
class SpatialAttn(nn.Module):
|
||||
""" Original CBAM spatial attention module
|
||||
"""
|
||||
def __init__(self, kernel_size=7):
|
||||
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
|
||||
super(SpatialAttn, self).__init__()
|
||||
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
|
||||
self.gate = create_act_layer(gate_layer)
|
||||
|
||||
def forward(self, x):
|
||||
x_avg = torch.mean(x, dim=1, keepdim=True)
|
||||
x_max = torch.max(x, dim=1, keepdim=True)[0]
|
||||
x_attn = torch.cat([x_avg, x_max], dim=1)
|
||||
x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
|
||||
x_attn = self.conv(x_attn)
|
||||
return x * x_attn.sigmoid()
|
||||
return x * self.gate(x_attn)
|
||||
|
||||
|
||||
class LightSpatialAttn(nn.Module):
|
||||
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
|
||||
"""
|
||||
def __init__(self, kernel_size=7):
|
||||
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
|
||||
super(LightSpatialAttn, self).__init__()
|
||||
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
|
||||
self.gate = create_act_layer(gate_layer)
|
||||
|
||||
def forward(self, x):
|
||||
x_avg = torch.mean(x, dim=1, keepdim=True)
|
||||
x_max = torch.max(x, dim=1, keepdim=True)[0]
|
||||
x_attn = 0.5 * x_avg + 0.5 * x_max
|
||||
x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
|
||||
x_attn = self.conv(x_attn)
|
||||
return x * x_attn.sigmoid()
|
||||
return x * self.gate(x_attn)
|
||||
|
||||
|
||||
class CbamModule(nn.Module):
|
||||
def __init__(self, channels, spatial_kernel_size=7):
|
||||
def __init__(
|
||||
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
|
||||
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
|
||||
super(CbamModule, self).__init__()
|
||||
self.channel = ChannelAttn(channels)
|
||||
self.spatial = SpatialAttn(spatial_kernel_size)
|
||||
self.channel = ChannelAttn(
|
||||
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
|
||||
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
|
||||
self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.channel(x)
|
||||
@ -87,9 +96,13 @@ class CbamModule(nn.Module):
|
||||
|
||||
|
||||
class LightCbamModule(nn.Module):
|
||||
def __init__(self, channels, spatial_kernel_size=7):
|
||||
def __init__(
|
||||
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
|
||||
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
|
||||
super(LightCbamModule, self).__init__()
|
||||
self.channel = LightChannelAttn(channels)
|
||||
self.channel = LightChannelAttn(
|
||||
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
|
||||
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
|
||||
self.spatial = LightSpatialAttn(spatial_kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -1,11 +1,23 @@
""" Select AttentionFactory Method
""" Attention Factory

Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from functools import partial

from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule
from .swin_attn import WindowAttention


def get_attn(attn_type):
@ -15,18 +27,54 @@ def get_attn(attn_type):
    if attn_type is not None:
        if isinstance(attn_type, str):
            attn_type = attn_type.lower()
            # Lightweight attention modules (channel and/or coarse spatial).
            # Typically added to existing network architecture blocks in addition to existing convolutions.
            if attn_type == 'se':
                module_cls = SEModule
            elif attn_type == 'ese':
                module_cls = EffectiveSEModule
            elif attn_type == 'eca':
                module_cls = EcaModule
            elif attn_type == 'ecam':
                module_cls = partial(EcaModule, use_mlp=True)
            elif attn_type == 'ceca':
                module_cls = CecaModule
            elif attn_type == 'ge':
                module_cls = GatherExcite
            elif attn_type == 'gc':
                module_cls = GlobalContext
            elif attn_type == 'cbam':
                module_cls = CbamModule
            elif attn_type == 'lcbam':
                module_cls = LightCbamModule

            # Attention / attention-like modules w/ significant params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'sk':
                module_cls = SelectiveKernel
            elif attn_type == 'splat':
                module_cls = SplitAttn

            # Self-attention / attention-like modules w/ significant compute and/or params
            # Typically replace some of the existing workhorse convs in a network architecture.
            # All of these accept a stride argument and can spatially downsample the input.
            elif attn_type == 'lambda':
                return LambdaLayer
            elif attn_type == 'bottleneck':
                return BottleneckAttn
            elif attn_type == 'halo':
                return HaloAttn
            elif attn_type == 'swin':
                return WindowAttention
            elif attn_type == 'involution':
                return Involution
            elif attn_type == 'nl':
                module_cls = NonLocalAttn
            elif attn_type == 'bat':
                module_cls = BatNonLocalAttn

            # Woops!
            else:
                assert False, "Invalid attn module (%s)" % attn_type
        elif isinstance(attn_type, bool):
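The payoff of the unified factory above is that all of the lightweight channel / spatial modules can be constructed and applied the same way. A small smoke-test sketch, assuming this revision of timm:

```python
import torch
from timm.models.layers import create_attn

x = torch.randn(2, 128, 14, 14)
for name in ('se', 'ese', 'eca', 'ecam', 'ceca', 'ge', 'gc', 'cbam', 'lcbam'):
    attn = create_attn(name, 128)     # create_attn(attn_type, channels, **kwargs)
    assert attn(x).shape == x.shape   # these lightweight modules preserve the feature shape
```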
@ -1,25 +0,0 @@
|
||||
from .bottleneck_attn import BottleneckAttn
|
||||
from .halo_attn import HaloAttn
|
||||
from .involution import Involution
|
||||
from .lambda_layer import LambdaLayer
|
||||
from .swin_attn import WindowAttention
|
||||
|
||||
|
||||
def get_self_attn(attn_type):
|
||||
if attn_type == 'bottleneck':
|
||||
return BottleneckAttn
|
||||
elif attn_type == 'halo':
|
||||
return HaloAttn
|
||||
elif attn_type == 'lambda':
|
||||
return LambdaLayer
|
||||
elif attn_type == 'swin':
|
||||
return WindowAttention
|
||||
elif attn_type == 'involution':
|
||||
return Involution
|
||||
else:
|
||||
assert False, f"Unknown attn type ({attn_type})"
|
||||
|
||||
|
||||
def create_self_attn(attn_type, dim, stride=1, **kwargs):
|
||||
attn_fn = get_self_attn(attn_type)
|
||||
return attn_fn(dim, stride=stride, **kwargs)
|
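With the deletion above, the former `get_self_attn` / `create_self_attn` lookups are handled by the unified factory. A rough before/after sketch, assuming this revision of timm:

```python
from timm.models.layers import get_attn

# previously: from timm.models.layers import get_self_attn; get_self_attn('halo')
HaloAttn = get_attn('halo')   # self-attn types are returned as classes by the unified factory
print(HaloAttn.__name__)      # HaloAttn
```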
@ -38,6 +38,10 @@ from torch import nn
import torch.nn.functional as F

from .create_act import create_act_layer
from .helpers import make_divisible


class EcaModule(nn.Module):
    """Constructs an ECA module.

@ -48,23 +52,48 @@ class EcaModule(nn.Module):
            refer to original paper https://arxiv.org/pdf/1910.03151.pdf
            (default=None. if channel size not given, use k_size given for kernel size.)
        kernel_size: Adaptive selection of kernel size (default=3)
        gamma: used in kernel_size calc, see above
        beta: used in kernel_size calc, see above
        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
        gate_layer: gating non-linearity to use
    """
    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
    def __init__(
            self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
            rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
        super(EcaModule, self).__init__()
        assert kernel_size % 2 == 1
        if channels is not None:
            t = int(abs(math.log(channels, 2) + beta) / gamma)
            kernel_size = max(t if t % 2 else t + 1, 3)

        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        assert kernel_size % 2 == 1
        padding = (kernel_size - 1) // 2
        if use_mlp:
            # NOTE 'mlp' mode is a timm experiment, not in paper
            assert channels is not None
            if rd_channels is None:
                rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
            act_layer = act_layer or nn.ReLU
            self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
            self.act = create_act_layer(act_layer)
            self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
        else:
            self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.act = None
            self.conv2 = None
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
        y = self.conv(y)
        y = y.view(x.shape[0], -1, 1, 1).sigmoid()
        if self.conv2 is not None:
            y = self.act(y)
            y = self.conv2(y)
        y = self.gate(y).view(x.shape[0], -1, 1, 1)
        return x * y.expand_as(x)


EfficientChannelAttn = EcaModule  # alias


class CecaModule(nn.Module):
    """Constructs a circular ECA module.

@ -83,25 +112,34 @@ class CecaModule(nn.Module):
            refer to original paper https://arxiv.org/pdf/1910.03151.pdf
            (default=None. if channel size not given, use k_size given for kernel size.)
        kernel_size: Adaptive selection of kernel size (default=3)
        gamma: used in kernel_size calc, see above
        beta: used in kernel_size calc, see above
        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
        gate_layer: gating non-linearity to use
    """

    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
        super(CecaModule, self).__init__()
        assert kernel_size % 2 == 1
        if channels is not None:
            t = int(abs(math.log(channels, 2) + beta) / gamma)
            kernel_size = max(t if t % 2 else t + 1, 3)
        has_act = act_layer is not None
        assert kernel_size % 2 == 1

        # PyTorch circular padding mode is buggy as of pytorch 1.4
        # see https://github.com/pytorch/pytorch/pull/17240
        # implement manual circular padding
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
        self.padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
        # Manually implement circular padding, F.pad does not seem to be bugged
        y = F.pad(y, (self.padding, self.padding), mode='circular')
        y = self.conv(y)
        y = y.view(x.shape[0], -1, 1, 1).sigmoid()
        y = self.gate(y).view(x.shape[0], -1, 1, 1)
        return x * y.expand_as(x)


CircularEfficientChannelAttn = CecaModule
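Editor's note (not part of the diff): a worked example of the adaptive kernel-size rule used by both modules above, plus a quick shape check. Values are illustrative; the module import path is an assumption based on this file's location (timm/models/layers/eca.py).

import math
import torch
from timm.models.layers.eca import EcaModule

channels, gamma, beta = 256, 2, 1
t = int(abs(math.log(channels, 2) + beta) / gamma)   # int(abs(8 + 1) / 2) = 4
kernel_size = max(t if t % 2 else t + 1, 3)          # even t is bumped to the next odd value -> 5

eca = EcaModule(channels=channels)                   # derives kernel_size=5 from channels internally
x = torch.randn(2, channels, 28, 28)
assert eca(x).shape == x.shape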
90
timm/models/layers/gather_excite.py
Normal file
@ -0,0 +1,90 @@
""" Gather-Excite Attention Block

Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348

Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet

I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another
impl that covers all of the cases.

NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation

Hacked together by / Copyright 2021 Ross Wightman
"""
import math

from torch import nn as nn
import torch.nn.functional as F

from .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlp


class GatherExcite(nn.Module):
    """ Gather-Excite Attention Module
    """
    def __init__(
            self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
            rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
                self.gather.add_module(
                    'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
                if norm_layer:
                    self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
            else:
                assert extent % 2 == 0
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f'conv{i + 1}',
                        create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
                    if norm_layer:
                        self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
                    if i != num_conv - 1:
                        self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        else:
            if self.extent == 0:
                # global extent
                x_ge = x.mean(dim=(2, 3), keepdims=True)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
            else:
                x_ge = F.avg_pool2d(
                    x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
        x_ge = self.mlp(x_ge)
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)
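Editor's note (not part of the diff): a usage sketch of GatherExcite as defined above. With extra_params=False the gather step is pooling only; extent=0 pools globally (SE-like), while a finite even extent pools over local neighbourhoods and the excite signal is interpolated back to the input size. Shapes are illustrative.

import torch
from timm.models.layers.gather_excite import GatherExcite

ge = GatherExcite(channels=128, extent=2)   # parameter-free gather: strided avg pool over 2x2 regions
x = torch.randn(2, 128, 32, 32)
assert ge(x).shape == x.shape               # excite output is upsampled back to 32x32 and gates x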
67
timm/models/layers/global_context.py
Normal file
@ -0,0 +1,67 @@
""" Global Context Attention Block

Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
    - https://arxiv.org/abs/1904.11492

Official code consulted as reference: https://github.com/xvjiarui/GCNet

Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
import torch.nn.functional as F

from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
from .mlp import ConvMlp
from .norm import LayerNorm2d


class GlobalContext(nn.Module):

    def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False,
                 rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
        super(GlobalContext, self).__init__()
        act_layer = get_act_layer(act_layer)

        self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None

        if rd_channels is None:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        if fuse_add:
            self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
        else:
            self.mlp_add = None
        if fuse_scale:
            self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
        else:
            self.mlp_scale = None

        self.gate = create_act_layer(gate_layer)
        self.init_last_zero = init_last_zero
        self.reset_parameters()

    def reset_parameters(self):
        if self.conv_attn is not None:
            nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
        if self.mlp_add is not None:
            nn.init.zeros_(self.mlp_add.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape

        if self.conv_attn is not None:
            attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)
            attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
            context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
            context = context.view(B, C, 1, 1)
        else:
            context = x.mean(dim=(2, 3), keepdim=True)

        if self.mlp_scale is not None:
            mlp_x = self.mlp_scale(context)
            x = x * self.gate(mlp_x)
        if self.mlp_add is not None:
            mlp_x = self.mlp_add(context)
            x = x + mlp_x

        return x
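Editor's note (not part of the diff): a usage sketch of GlobalContext with its defaults, i.e. softmax attention pooling of context followed by the additive-fusion 1x1 conv MLP. Shapes are illustrative.

import torch
from timm.models.layers.global_context import GlobalContext

gc = GlobalContext(channels=64)      # use_attn=True, fuse_add=True by default
x = torch.randn(2, 64, 14, 14)
assert gc(x).shape == x.shape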
@ -28,4 +28,4 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v
    return new_v
@ -16,7 +16,7 @@ class Involution(nn.Module):
            kernel_size=3,
            stride=1,
            group_size=16,
            reduction_ratio=4,
            rd_ratio=4,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU,
    ):
@ -28,12 +28,12 @@ class Involution(nn.Module):
        self.groups = self.channels // self.group_size
        self.conv1 = ConvBnAct(
            in_channels=channels,
            out_channels=channels // reduction_ratio,
            out_channels=channels // rd_ratio,
            kernel_size=1,
            norm_layer=norm_layer,
            act_layer=act_layer)
        self.conv2 = self.conv = create_conv2d(
            in_channels=channels // reduction_ratio,
            in_channels=channels // rd_ratio,
            out_channels=kernel_size**2 * self.groups,
            kernel_size=1,
            stride=1)
@ -77,3 +77,26 @@ class GatedMlp(nn.Module):
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
145
timm/models/layers/non_local_attn.py
Normal file
@ -0,0 +1,145 @@
""" Bilinear-Attention-Transform and Non-Local Attention

Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
    - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
"""
import torch
from torch import nn
from torch.nn import functional as F

from .conv_bn_act import ConvBnAct
from .helpers import make_divisible


class NonLocalAttn(nn.Module):
    """Spatial NL block for image classification.

    This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
    Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.
    """

    def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
        super(NonLocalAttn, self).__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.scale = in_channels ** -0.5 if use_scale else 1.0
        self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
        self.norm = nn.BatchNorm2d(in_channels)
        self.reset_parameters()

    def forward(self, x):
        shortcut = x

        t = self.t(x)
        p = self.p(x)
        g = self.g(x)

        B, C, H, W = t.size()
        t = t.view(B, C, -1).permute(0, 2, 1)
        p = p.view(B, C, -1)
        g = g.view(B, C, -1).permute(0, 2, 1)

        att = torch.bmm(t, p) * self.scale
        att = F.softmax(att, dim=2)
        x = torch.bmm(att, g)

        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.z(x)
        x = self.norm(x) + shortcut

        return x

    def reset_parameters(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if len(list(m.parameters())) > 1:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)


class BilinearAttnTransform(nn.Module):

    def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(BilinearAttnTransform, self).__init__()

        self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
        self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
        self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.block_size = block_size
        self.groups = groups
        self.in_channels = in_channels

    def resize_mat(self, x, t):
        B, C, block_size, block_size1 = x.shape
        assert block_size == block_size1
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)
        x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
        x = x.view(B * C, block_size, block_size, t, t)
        x = torch.cat(torch.split(x, 1, dim=1), dim=3)
        x = torch.cat(torch.split(x, 1, dim=2), dim=4)
        x = x.view(B, C, block_size * t, block_size * t)
        return x

    def forward(self, x):
        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
        B, C, H, W = x.shape
        out = self.conv1(x)
        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
        cp = F.adaptive_max_pool2d(out, (1, self.block_size))
        p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
        q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
        p = F.sigmoid(p)
        q = F.sigmoid(q)
        p = p / p.sum(dim=3, keepdim=True)
        q = q / q.sum(dim=2, keepdim=True)
        p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
            0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        p = p.view(B, C, self.block_size, self.block_size)
        q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
            0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        q = q.view(B, C, self.block_size, self.block_size)
        p = self.resize_mat(p, H // self.block_size)
        q = self.resize_mat(q, W // self.block_size)
        y = p.matmul(x)
        y = y.matmul(q)

        y = self.conv2(y)
        return y


class BatNonLocalAttn(nn.Module):
    """ BAT
    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
    """

    def __init__(
            self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
            drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
        self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.dropout = nn.Dropout2d(p=drop_rate)

    def forward(self, x):
        xl = self.conv1(x)
        y = self.ba(xl)
        y = self.conv2(y)
        y = self.dropout(y)
        return y + x
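Editor's note (not part of the diff): a usage sketch of the two blocks above. BatNonLocalAttn requires the spatial dims to be divisible by block_size (7 by default), which 28x28 satisfies; both blocks are residual and preserve the input shape. Shapes are illustrative.

import torch
from timm.models.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn

x = torch.randn(2, 64, 28, 28)
nl = NonLocalAttn(64)                 # reduced internally to 64 * 1/8 = 8 channels
bat = BatNonLocalAttn(64)             # block_size=7 divides 28 evenly
assert nl(x).shape == x.shape
assert bat(x).shape == x.shape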
@ -12,3 +12,12 @@ class GroupNorm(nn.GroupNorm):

    def forward(self, x):
        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


class LayerNorm2d(nn.LayerNorm):
    """ Layernorm for channels of '2d' spatial BCHW tensors """
    def __init__(self, num_channels):
        super().__init__([num_channels, 1, 1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
@ -1,50 +0,0 @@
from torch import nn as nn
import torch.nn.functional as F

from .create_act import create_act_layer
from .helpers import make_divisible


class SEModule(nn.Module):
    """ SE Module as defined in original SE-Nets with a few additions
    Additions include:
        * min_channels can be specified to keep reduced channel count at a minimum (default: 8)
        * divisor can be specified to keep channels rounded to specified values (default: 1)
        * reduction channels can be specified directly by arg (if reduction_channels is set)
        * reduction channels can be specified by float ratio (if reduction_ratio is set)
    """
    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
                 reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
        super(SEModule, self).__init__()
        if reduction_channels is not None:
            reduction_channels = reduction_channels  # direct specification highest priority, no rounding/min done
        elif reduction_ratio is not None:
            reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
        else:
            reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)


class EffectiveSEModule(nn.Module):
    """ 'Effective Squeeze-Excitation
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
    """
    def __init__(self, channels, gate_layer='hard_sigmoid'):
        super(EffectiveSEModule, self).__init__()
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc(x_se)
        return x * self.gate(x_se)
@ -8,6 +8,7 @@ import torch
from torch import nn as nn

from .conv_bn_act import ConvBnAct
from .helpers import make_divisible


def _kernel_valid(k):
@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module):
        return x


class SelectiveKernelConv(nn.Module):
class SelectiveKernel(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
                 attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
    def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
                 rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
        """ Selective Kernel Convolution Module

@ -66,8 +67,7 @@ class SelectiveKernelConv(nn.Module):
            stride (int): stride for convolutions
            dilation (int): dilation for module as a whole, impacts dilation of each branch
            groups (int): number of groups for each branch
            attn_reduction (int, float): reduction factor for attention features
            min_attn_channels (int): minimum attention feature channels
            rd_ratio (int, float): reduction factor for attention features
            keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
            split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
                can be viewed as grouping by path, output expands to module out_channels count
@ -75,7 +75,8 @@ class SelectiveKernelConv(nn.Module):
            act_layer (nn.Module): activation layer to use
            norm_layer (nn.Module): batchnorm/norm layer to use
        """
        super(SelectiveKernelConv, self).__init__()
        super(SelectiveKernel, self).__init__()
        out_channels = out_channels or in_channels
        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
@ -101,7 +102,7 @@ class SelectiveKernelConv(nn.Module):
            ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
            for k, d in zip(kernel_size, dilation)])

        attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
        attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
        self.drop_block = drop_block
@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F
from torch import nn

from .helpers import make_divisible


class RadixSoftmax(nn.Module):
    def __init__(self, radix, cardinality):
@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module):
        return x


class SplitAttnConv2d(nn.Module):
    """Split-Attention Conv2d
class SplitAttn(nn.Module):
    """Split-Attention (aka Splat)
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
        super(SplitAttnConv2d, self).__init__()
        super(SplitAttn, self).__init__()
        out_channels = out_channels or in_channels
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = out_channels * radix
        attn_chs = max(in_channels * radix // reduction_factor, 32)
        if rd_channels is None:
            attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
        else:
            attn_chs = rd_channels * radix

        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
            in_channels, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.act0 = act_layer(inplace=True)
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
        self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        self.rsoftmax = RadixSoftmax(radix, groups)

    @property
    def in_channels(self):
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.fc1.out_channels

    def forward(self, x):
        x = self.conv(x)
        if self.bn0 is not None:
            x = self.bn0(x)
        x = self.bn0(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act0(x)
@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module):
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        x_gap = F.adaptive_avg_pool2d(x_gap, 1)
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
        if self.bn1 is not None:
            x_gap = self.bn1(x_gap)
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)
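Editor's note (not part of the diff): a usage sketch of the renamed SplitAttn module; out_channels and padding now default from in_channels / kernel_size, and the norm slots fall back to nn.Identity when no norm_layer is given. Shapes are illustrative.

import torch
from torch import nn
from timm.models.layers.split_attn import SplitAttn

sa = SplitAttn(64, radix=2, norm_layer=nn.BatchNorm2d)   # out_channels defaults to 64, padding to kernel_size // 2
x = torch.randn(2, 64, 28, 28)
assert sa(x).shape == x.shape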
74
timm/models/layers/squeeze_excite.py
Normal file
@ -0,0 +1,74 @@
""" Squeeze-and-Excitation Channel Attention

An SE implementation originally based on PyTorch SE-Net impl.
Has since evolved with additional functionality / configuration.

Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507

Also included is Effective Squeeze-Excitation (ESE).
Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667

Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn

from .create_act import create_act_layer
from .helpers import make_divisible


class SEModule(nn.Module):
    """ SE Module as defined in original SE-Nets with a few additions
    Additions include:
        * divisor can be specified to keep channels % div == 0 (default: 8)
        * reduction channels can be specified directly by arg (if rd_channels is set)
        * reduction channels can be specified by float rd_ratio (default: 1/16)
        * global max pooling can be added to the squeeze aggregation
        * customizable activation, normalization, and gate layer
    """
    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
        super(SEModule, self).__init__()
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
        self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
        x_se = self.fc1(x_se)
        x_se = self.act(self.bn(x_se))
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)


SqueezeExcite = SEModule  # alias


class EffectiveSEModule(nn.Module):
    """ 'Effective Squeeze-Excitation
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
    """
    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
        super(EffectiveSEModule, self).__init__()
        self.add_maxpool = add_maxpool
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
        x_se = self.fc(x_se)
        return x * self.gate(x_se)


EffectiveSqueezeExcite = EffectiveSEModule  # alias
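Editor's note (not part of the diff): the refactored reduction arguments in a nutshell. rd_ratio scales the input channel count, rd_divisor rounds the result, and rd_channels overrides both; the sketch below is illustrative.

import torch
from timm.models.layers.squeeze_excite import SEModule, EffectiveSEModule

se = SEModule(128, rd_ratio=1. / 4)     # 128 * 1/4 = 32 reduction channels (already a multiple of rd_divisor=8)
ese = EffectiveSEModule(128)            # single 1x1 conv, no channel reduction
x = torch.randn(2, 128, 14, 14)
assert se(x).shape == x.shape
assert ese(x).shape == x.shape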
@ -72,6 +72,10 @@ default_cfgs = {
    'tf_mobilenetv3_small_minimal_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),

    'fbnetv3_b': _cfg(),
    'fbnetv3_d': _cfg(),
    'fbnetv3_g': _cfg(),
}


@ -86,7 +90,7 @@ class MobileNetV3(nn.Module):
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                 pad_type='', act_layer=None, norm_layer=None, se_layer=None,
                 pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
                 round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
        super(MobileNetV3, self).__init__()
        act_layer = act_layer or nn.ReLU
@ -104,7 +108,7 @@ class MobileNetV3(nn.Module):

        # Middle stages (IR/ER/DS Blocks)
        builder = EfficientNetBuilder(
            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn,
            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
        self.blocks = nn.Sequential(*builder(stem_size, block_args))
        self.feature_info = builder.features
@ -161,8 +165,8 @@ class MobileNetV3Features(nn.Module):
    and object detection models.
    """

    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                 in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels,
    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
                 stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
                 act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
        super(MobileNetV3Features, self).__init__()
        act_layer = act_layer or nn.ReLU
@ -178,7 +182,7 @@ class MobileNetV3Features(nn.Module):

        # Middle stages (IR/ER/DS Blocks)
        builder = EfficientNetBuilder(
            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
            drop_path_rate=drop_path_rate, feature_location=feature_location)
        self.blocks = nn.Sequential(*builder(stem_size, block_args))
@ -262,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False),
        se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
        **kwargs,
    )
    model = _create_mnv3(variant, pretrained, **model_kwargs)
@ -350,8 +354,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
            # stage 6, 7x7 in
            ['cn_r1_k1_s1_c960'],  # hard-swish
        ]
    se_layer = partial(
        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
@ -366,6 +369,67 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
    return model


def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """ FBNetV3
    Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
        - https://arxiv.org/abs/2006.02049
    FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
    """
    vl = variant.split('_')[-1]
    if vl in ('a', 'b'):
        stem_size = 16
        arch_def = [
            ['ds_r2_k3_s1_e1_c16'],
            ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
            ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
            ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
            ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
            ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
            ['cn_r1_k1_s1_c1344'],
        ]
    elif vl == 'd':
        stem_size = 24
        arch_def = [
            ['ds_r2_k3_s1_e1_c16'],
            ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
            ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
            ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
            ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
            ['cn_r1_k1_s1_c1440'],
        ]
    elif vl == 'g':
        stem_size = 32
        arch_def = [
            ['ds_r3_k3_s1_e1_c24'],
            ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
            ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
            ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
            ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
            ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
            ['cn_r1_k1_s1_c1728'],
        ]
    else:
        raise NotImplementedError
    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)
    act_layer = resolve_act_layer(kwargs, 'hard_swish')
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=1984,
        head_bias=False,
        stem_size=stem_size,
        round_chs_fn=round_chs_fn,
        se_from_exp=False,
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=act_layer,
        se_layer=se_layer,
        **kwargs,
    )
    model = _create_mnv3(variant, pretrained, **model_kwargs)
    return model


@register_model
def mobilenetv3_large_075(pretrained=False, **kwargs):
    """ MobileNet V3 """
@ -474,3 +538,24 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def fbnetv3_b(pretrained=False, **kwargs):
    """ FBNetV3-B """
    model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
    return model


@register_model
def fbnetv3_d(pretrained=False, **kwargs):
    """ FBNetV3-D """
    model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
    return model


@register_model
def fbnetv3_g(pretrained=False, **kwargs):
    """ FBNetV3-G """
    model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
    return model
|
||||
|
||||
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
|
||||
num_features = 1280 * channels[-1] // 440
|
||||
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
|
||||
attn_kwargs = dict(rd_ratio=0.5)
|
||||
cfg = NfCfg(
|
||||
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
|
||||
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
|
||||
@ -193,7 +193,7 @@ def _nfnet_cfg(
|
||||
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
|
||||
act_layer='gelu', attn_layer='se', attn_kwargs=None):
|
||||
num_features = int(channels[-1] * feat_mult)
|
||||
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8)
|
||||
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
|
||||
cfg = NfCfg(
|
||||
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
|
||||
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
|
||||
@ -202,11 +202,10 @@ def _nfnet_cfg(
|
||||
|
||||
|
||||
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
|
||||
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
|
||||
cfg = NfCfg(
|
||||
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
|
||||
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
|
||||
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
|
||||
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
|
||||
return cfg
|
||||
|
||||
|
||||
@ -243,7 +242,7 @@ model_cfgs = dict(
|
||||
# Experimental 'light' versions of NFNet-F that are little leaner
|
||||
nfnet_l0=_nfnet_cfg(
|
||||
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
|
||||
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
|
||||
attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
|
||||
eca_nfnet_l0=_nfnet_cfg(
|
||||
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
|
||||
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
|
||||
@ -272,9 +271,9 @@ model_cfgs = dict(
|
||||
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
|
||||
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
|
||||
|
||||
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
|
||||
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
|
||||
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
|
||||
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
|
||||
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
|
||||
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
|
||||
|
||||
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
|
||||
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
|
||||
|
@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
            groups=groups, **cargs)
        if se_ratio:
            se_channels = int(round(in_chs * se_ratio))
            self.se = SEModule(bottleneck_chs, reduction_channels=se_channels)
            self.se = SEModule(bottleneck_chs, rd_channels=se_channels)
        else:
            self.se = None
        cargs['act_layer'] = None
@ -11,7 +11,7 @@ from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SplitAttnConv2d
from .layers import SplitAttn
from .registry import register_model
from .resnet import ResNet

@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module):
        self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None

        if self.radix >= 1:
            self.conv2 = SplitAttnConv2d(
            self.conv2 = SplitAttn(
                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
            self.bn2 = None  # FIXME revisit, here to satisfy current torchscript fussyness
            self.act2 = None
            self.bn2 = nn.Identity()
            self.act2 = nn.Identity()
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module):
            out = self.avd_first(out)

        out = self.conv2(out)
        if self.bn2 is not None:
            out = self.bn2(out)
            if self.drop_block is not None:
                out = self.drop_block(out)
            out = self.act2(out)
        out = self.bn2(out)
        if self.drop_block is not None:
            out = self.drop_block(out)
        out = self.act2(out)

        if self.avd_last is not None:
            out = self.avd_last(out)
@ -1122,7 +1122,7 @@ def resnetrs50(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1135,7 +1135,7 @@ def resnetrs101(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1148,7 +1148,7 @@ def resnetrs152(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1161,7 +1161,7 @@ def resnetrs200(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1174,7 +1174,7 @@ def resnetrs270(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1188,7 +1188,7 @@ def resnetrs350(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1201,7 +1201,7 @@ def resnetrs420(pretrained=False, **kwargs):
    Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
    Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
    """
    attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
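Editor's note (not part of the diff): what the updated argument name means for the ResNet-RS configs above; rd_ratio=0.25 gives the SE block a quarter of the block's channels. The numbers below are illustrative.

from functools import partial
from timm.models.layers.create_attn import get_attn

se_layer = partial(get_attn('se'), rd_ratio=0.25)
attn = se_layer(256)                      # SEModule(256, rd_ratio=0.25) -> 64 reduction channels
assert attn.fc1.out_channels == 64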
@ -11,11 +11,12 @@ Copyright 2020 Ross Wightman
"""

import torch.nn as nn
from functools import partial
from math import ceil

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights

@ -48,26 +49,7 @@ default_cfgs = dict(
        url=''),
)


class SEWithNorm(nn.Module):

    def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                 gate_layer='sigmoid'):
        super(SEWithNorm, self).__init__()
        reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor)
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
        self.bn = nn.BatchNorm2d(reduction_channels)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc1(x_se)
        x_se = self.bn(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)
SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)


class LinearBottleneck(nn.Module):
@ -86,7 +68,10 @@ class LinearBottleneck(nn.Module):
            self.conv_exp = None

        self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
        self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
        if se_ratio > 0:
            self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
        else:
            self.se = None
        self.act_dw = create_act_layer(dw_act_layer)

        self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
@ -14,7 +14,7 @@ from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .layers import SelectiveKernel, ConvBnAct, create_attn
from .registry import register_model
from .resnet import ResNet

@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module):
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = SelectiveKernelConv(
        self.conv1 = SelectiveKernel(
            inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
        conv_kwargs['act_layer'] = None
        self.conv2 = ConvBnAct(
@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module):
        first_dilation = first_dilation or dilation

        self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
        self.conv2 = SelectiveKernelConv(
        self.conv2 = SelectiveKernel(
            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
            **conv_kwargs, **sk_kwargs)
        conv_kwargs['act_layer'] = None
@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.
    """
    sk_kwargs = dict(
        min_attn_channels=16,
        attn_reduction=8,
        split_input=True)
    sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
    model_args = dict(
        block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last_bn=False, **kwargs)
@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.
    """
    sk_kwargs = dict(
        min_attn_channels=16,
        attn_reduction=8,
        split_input=True)
    sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
    model_args = dict(
        block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last_bn=False, **kwargs)
@ -213,8 +207,9 @@ def skresnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
    the SKNet-50 model in the Select Kernel Paper
    """
    sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
    model_args = dict(
        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        zero_init_last_bn=False, **kwargs)
        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
    return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)
@ -84,8 +84,8 @@ class BasicBlock(nn.Module):
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        reduction_chs = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
        rd_chs = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None

    def forward(self, x):
        if self.downsample is not None:
@ -125,7 +125,7 @@ class Bottleneck(nn.Module):
                aa_layer(channels=planes, filt_size=3, stride=2))

        reduction_chs = max(planes * self.expansion // 8, 64)
        self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
        self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None

        self.conv3 = conv2d_iabn(
            planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
@ -13,7 +13,7 @@ import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
from .registry import register_model


@ -39,15 +39,6 @@ default_cfgs = dict(
)


class LayerNormBHWC(nn.LayerNorm):
    def __init__(self, dim):
        super().__init__(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)


class SpatialMlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
@ -119,7 +110,7 @@ class Attention(nn.Module):

class Block(nn.Module):
    def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
                 group=8, attn_disabled=False, spatial_conv=False):
        super().__init__()
        self.spatial_conv = spatial_conv
@ -148,7 +139,7 @@ class Block(nn.Module):
class Visformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
                 depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111',
                 norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
                 vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
        super().__init__()
        self.num_classes = num_classes
@ -1 +1 @@
__version__ = '0.4.10'
__version__ = '0.4.11'