Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Layer refactoring continues, ResNet downsample rewrite for proper dilation in 3x3 and avg_pool cases

* select_conv2d -> create_conv2d
* added create_attn to create attention module from string/bool/module
* factor padding helpers into own file, use in both conv2d_same and avg_pool2d_same
* add some more test eca resnet variants
* minor tweaks, naming, comments, consistency

This commit is contained in:
parent a99ec4e7d1
commit f902bcd54c
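For orientation, a minimal usage sketch of the renamed factory helpers this commit introduces (import paths as added in timm/models/layers/__init__.py below; the channel counts and input size are illustrative assumptions, not part of the diff):

import torch
from timm.models.layers import create_conv2d, create_attn

# create_conv2d picks nn.Conv2d, Conv2dSame, MixedConv2d or CondConv2d from its kwargs
conv = create_conv2d(32, 64, kernel_size=3, stride=2, padding='same')

# create_attn resolves a string ('se', 'eca'), a bool, or a module class to an attention module
attn = create_attn('se', 64)

x = torch.randn(1, 32, 224, 224)
y = attn(conv(x)) if attn is not None else conv(x)
print(y.shape)  # torch.Size([1, 64, 112, 112])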
@@ -28,7 +28,7 @@ from .feature_hooks import FeatureHooks
 from .registry import register_model
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.models.layers import select_conv2d
+from timm.models.layers import create_conv2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD


@@ -220,7 +220,7 @@ class EfficientNet(nn.Module):

     def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
+                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(EfficientNet, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -232,21 +232,21 @@ class EfficientNet(nn.Module):

         # Stem
         stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
         self._in_chs = stem_size

         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs,
+            channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features
         self._in_chs = builder.in_chs

         # Head + Pooling
-        self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
+        self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
         self.bn2 = norm_layer(self.num_features, **norm_kwargs)
         self.act2 = act_layer(inplace=True)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -314,7 +314,7 @@ class EfficientNetFeatures(nn.Module):

         # Stem
         stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
         self._in_chs = stem_size

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from .layers.activations import sigmoid
-from .layers import select_conv2d
+from .layers import create_conv2d


 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@@ -129,7 +129,7 @@ class ConvBnAct(nn.Module):
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(ConvBnAct, self).__init__()
         norm_kwargs = norm_kwargs or {}
-        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
+        self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn1 = norm_layer(out_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

@@ -162,7 +162,7 @@ class DepthwiseSeparableConv(nn.Module):
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_connect_rate = drop_connect_rate

-        self.conv_dw = select_conv2d(
+        self.conv_dw = create_conv2d(
             in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
         self.bn1 = norm_layer(in_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
@@ -174,7 +174,7 @@ class DepthwiseSeparableConv(nn.Module):
         else:
             self.se = None

-        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
+        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

@@ -223,12 +223,12 @@ class InvertedResidual(nn.Module):
         self.drop_connect_rate = drop_connect_rate

         # Point-wise expansion
-        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
+        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn1 = norm_layer(mid_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

         # Depth-wise convolution
-        self.conv_dw = select_conv2d(
+        self.conv_dw = create_conv2d(
             mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
             padding=pad_type, depthwise=True, **conv_kwargs)
         self.bn2 = norm_layer(mid_chs, **norm_kwargs)
@@ -242,7 +242,7 @@ class InvertedResidual(nn.Module):
             self.se = None

         # Point-wise linear projection
-        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
+        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn3 = norm_layer(out_chs, **norm_kwargs)

     def feature_module(self, location):
@@ -356,7 +356,7 @@ class EdgeResidual(nn.Module):
         self.drop_connect_rate = drop_connect_rate

         # Expansion convolution
-        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
+        self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
         self.bn1 = norm_layer(mid_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

@@ -368,7 +368,7 @@ class EdgeResidual(nn.Module):
             self.se = None

         # Point-wise linear projection
-        self.conv_pwl = select_conv2d(
+        self.conv_pwl = create_conv2d(
             mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)

@@ -11,6 +11,7 @@ import torch.nn.functional as F

 from .registry import register_model
 from .helpers import load_pretrained
+from .layers import SEModule
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 from .resnet import ResNet, Bottleneck, BasicBlock
@@ -319,8 +320,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
     """
     default_cfg = default_cfgs['gluon_seresnext50_32x4d']
     model = ResNet(
-        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -333,8 +334,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
     """
     default_cfg = default_cfgs['gluon_seresnext101_32x4d']
     model = ResNet(
-        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -346,9 +347,10 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
     """Constructs a SEResNeXt-101-64x4d model.
     """
     default_cfg = default_cfgs['gluon_seresnext101_64x4d']
+    block_args = dict(attn_layer=SEModule)
     model = ResNet(
-        Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
+        num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -360,10 +362,10 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs an SENet-154 model.
     """
     default_cfg = default_cfgs['gluon_senet154']
+    block_args = dict(attn_layer=SEModule)
     model = ResNet(
-        Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True,
-        stem_type='deep', down_kernel_size=3, block_reduce_first=2,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3,
+        block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)

@@ -1,8 +1,13 @@
+from .padding import get_padding
+from .avg_pool2d_same import AvgPool2dSame
+from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .mixed_conv2d import MixedConv2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .select_conv2d import select_conv2d
+from .create_conv2d import create_conv2d
+from .create_attn import create_attn
 from .selective_kernel import SelectiveKernelConv
+from .se import SEModule
 from .eca import EcaModule, CecaModule
 from .activations import *
 from .adaptive_avgmax_pool import \
timm/models/layers/avg_pool2d_same.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+""" AvgPool2d w/ Same Padding
+
+Hacked together by Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+import math
+
+from .helpers import tup_pair
+from .padding import pad_same
+
+
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+                    ceil_mode: bool = False, count_include_pad: bool = True):
+    x = pad_same(x, kernel_size, stride)
+    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+    """ Tensorflow like 'SAME' wrapper for 2D average pooling
+    """
+    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+        kernel_size = tup_pair(kernel_size)
+        stride = tup_pair(stride)
+        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+    def forward(self, x):
+        return avg_pool2d_same(
+            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
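A small hedged sketch of how the new AvgPool2dSame behaves next to a plain nn.AvgPool2d (the tensor size is an assumption for illustration):

import torch
import torch.nn as nn
from timm.models.layers import AvgPool2dSame

x = torch.randn(1, 8, 7, 7)
# plain PyTorch pooling drops the trailing row/column on odd inputs: 7 -> 3
print(nn.AvgPool2d(2, 2)(x).shape)   # torch.Size([1, 8, 3, 3])
# the TF-style 'SAME' variant pads asymmetrically first, so 7 -> ceil(7/2) = 4
print(AvgPool2dSame(2, 2)(x).shape)  # torch.Size([1, 8, 4, 4])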
@@ -10,8 +10,8 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F

+from .helpers import tup_pair
 from .conv2d_same import get_padding_value, conv2d_same
-from .conv_helpers import tup_pair


 def get_condconv_initializer(initializer, num_experts, expert_shape):
@@ -8,26 +8,13 @@ import torch.nn.functional as F
 from typing import Union, List, Tuple, Optional, Callable
 import math

-from .conv_helpers import get_padding
+from .padding import get_padding, pad_same, is_static_pad


-def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
-    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
-
-
-def _calc_same_pad(i: int, k: int, s: int, d: int):
-    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
-
-
 def conv2d_same(
         x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
         padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
-    ih, iw = x.size()[-2:]
-    kh, kw = weight.size()[-2:]
-    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
-    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
-    if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+    x = pad_same(x, weight.shape[-2:], stride, dilation)
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


@@ -51,7 +38,7 @@ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
         padding = padding.lower()
         if padding == 'same':
             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
-            if _is_static_pad(kernel_size, **kwargs):
+            if is_static_pad(kernel_size, **kwargs):
                 # static case, no extra overhead
                 padding = get_padding(kernel_size, **kwargs)
             else:
@@ -4,7 +4,7 @@ Hacked together by Ross Wightman
 """
 from torch import nn as nn

-from timm.models.layers.conv_helpers import get_padding
+from timm.models.layers import get_padding


 class ConvBnAct(nn.Module):
timm/models/layers/create_attn.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+""" Select AttentionFactory Method
+
+Hacked together by Ross Wightman
+"""
+import torch
+from .se import SEModule
+from .eca import EcaModule, CecaModule
+
+
+def create_attn(attn_type, channels, **kwargs):
+    module_cls = None
+    if attn_type is not None:
+        if isinstance(attn_type, str):
+            attn_type = attn_type.lower()
+            if attn_type == 'se':
+                module_cls = SEModule
+            elif attn_type == 'eca':
+                module_cls = EcaModule
+            elif attn_type == 'eca':
+                module_cls = CecaModule
+            else:
+                assert False, "Invalid attn module (%s)" % attn_type
+        elif isinstance(attn_type, bool):
+            if attn_type:
+                module_cls = SEModule
+        else:
+            module_cls = attn_type
+    if module_cls is not None:
+        return module_cls(channels, **kwargs)
+    return None
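A hedged sketch of how the new factory resolves its argument as a string, a bool, or a module class, using only names added in this commit (channel counts are illustrative):

import torch.nn as nn
from timm.models.layers import create_attn, SEModule, EcaModule

se = create_attn('se', 128)       # string -> SEModule(128)
eca = create_attn('eca', 128)     # string -> EcaModule(128)
off = create_attn(False, 128)     # falsy bool -> None, block keeps self.se = None
on = create_attn(True, 128)       # True -> SEModule by default
custom = create_attn(SEModule, 128, reduction=8)  # a module class is instantiated directly

assert isinstance(se, SEModule) and isinstance(eca, EcaModule)
assert off is None and isinstance(on, nn.Module)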
@@ -1,4 +1,4 @@
-""" Select Conv2d Factory Method
+""" Create Conv2d Factory Method

 Hacked together by Ross Wightman
 """
@@ -8,7 +8,7 @@ from .cond_conv2d import CondConv2d
 from .conv2d_same import create_conv2d_pad


-def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
+def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
     """ Select a 2d convolution implementation based on arguments
     Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.

@@ -1,3 +1,9 @@
+""" DropBlock, DropPath
+
+PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+
+Hacked together by Ross Wightman
+"""
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,6 +12,8 @@ import math


 def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
+    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """
     _, _, height, width = x.shape
     total_size = width * height
     clipped_block_size = min(block_size, min(width, height))
@@ -24,7 +32,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi
     block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
     block_mask = -F.max_pool2d(
         -block_mask,
-        kernel_size=clipped_block_size,  # block_size,
+        kernel_size=clipped_block_size,  # block_size, ???
         stride=1,
         padding=clipped_block_size // 2)

@@ -58,7 +66,8 @@ class DropBlock2d(nn.Module):


 def drop_path(x, drop_prob=0.):
-    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks)."""
+    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
+    """
     keep_prob = 1 - drop_prob
     random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
     random_tensor.floor_()  # binarize
@@ -67,6 +76,8 @@ def drop_path(x, drop_prob=0.):


 class DropPath(nn.ModuleDict):
+    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
+    """
     def __init__(self, drop_prob=None):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
@@ -47,19 +47,20 @@ class EcaModule(nn.Module):
         gamma, beta: when channel is given parameters of mapping function
             refer to original paper https://arxiv.org/pdf/1910.03151.pdf
             (default=None. if channel size not given, use k_size given for kernel size.)
-        k_size: Adaptive selection of kernel size (default=3)
+        kernel_size: Adaptive selection of kernel size (default=3)
     """
-    def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
+    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
-        assert k_size % 2 == 1
+        assert kernel_size % 2 == 1

-        if channel is not None:
-            t = int(abs(math.log(channel, 2)+beta) / gamma)
-            k_size = t if t % 2 else t + 1
+        if channels is not None:
+            t = int(abs(math.log(channels, 2) + beta) / gamma)
+            kernel_size = max(t if t % 2 else t + 1, 3)
+
+        print('florg', kernel_size)

         self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
-        self.sigmoid = nn.Sigmoid()
+        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)

     def forward(self, x):
         # Feature descriptor on the global spatial information
@@ -69,7 +70,7 @@ class EcaModule(nn.Module):
         # Two different branches of ECA module
         y = self.conv(y)
         # Multi-scale information fusion
-        y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
+        y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)


@@ -93,22 +94,21 @@ class CecaModule(nn.Module):
         k_size: Adaptive selection of kernel size (default=3)
     """

-    def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
+    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
-        assert k_size % 2 == 1
+        assert kernel_size % 2 == 1

-        if channel is not None:
-            t = int(abs(math.log(channel, 2)+beta) / gamma)
-            k_size = t if t % 2 else t + 1
+        if channels is not None:
+            t = int(abs(math.log(channels, 2) + beta) / gamma)
+            kernel_size = max(t if t % 2 else t + 1, 3)

         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         #pytorch circular padding mode is buggy as of pytorch 1.4
         #see https://github.com/pytorch/pytorch/pull/17240

         #implement manual circular padding
-        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
-        self.padding = (k_size - 1) // 2
-        self.sigmoid = nn.Sigmoid()
+        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
+        self.padding = (kernel_size - 1) // 2

     def forward(self, x):
         # Feature descriptor on the global spatial information
@@ -121,6 +121,6 @@ class CecaModule(nn.Module):
         y = self.conv(y)

         # Multi-scale information fusion
-        y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
+        y = y.view(x.shape[0], -1, 1, 1).sigmoid()

         return x * y.expand_as(x)
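The renamed ECA arguments also gain a floor of 3 on the adaptively chosen kernel size; a hedged sanity check of that formula with the default gamma=2, beta=1 (channel counts chosen for illustration):

import math

def eca_kernel_size(channels, gamma=2, beta=1):
    # mirrors the updated EcaModule logic: odd kernel, never smaller than 3
    t = int(abs(math.log(channels, 2) + beta) / gamma)
    return max(t if t % 2 else t + 1, 3)

assert eca_kernel_size(4) == 3    # t = int((2 + 1) / 2) = 1, odd but tiny -> raised to the new floor of 3
assert eca_kernel_size(512) == 5  # t = int((9 + 1) / 2) = 5, already odd and >= 3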
@@ -1,4 +1,4 @@
-""" Common Helpers
+""" Layer/Module Helpers

 Hacked together by Ross Wightman
 """
@@ -21,7 +21,3 @@ tup_triple = _ntuple(3)
 tup_quadruple = _ntuple(4)


-# Calculate symmetric padding for a convolution
-def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
-    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
-    return padding
timm/models/layers/padding.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+""" Padding Helpers
+
+Hacked together by Ross Wightman
+"""
+import math
+from typing import List
+
+import torch.nn.functional as F
+
+
+# Calculate symmetric padding for a convolution
+def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
+    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+    return padding
+
+
+# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
+def get_same_padding(x: int, k: int, s: int, d: int):
+    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+
+
+# Can SAME padding for given args be done statically?
+def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
+    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
+
+
+# Dynamically pad input x with 'SAME' padding for conv with specified args
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
+    ih, iw = x.size()[-2:]
+    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+    return x
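The arithmetic behind get_same_padding is easy to check by hand; a small sketch with illustrative values (the function body is copied from the helper above):

import math

def get_same_padding(x: int, k: int, s: int, d: int):
    # total padding needed so that the output length is ceil(x / s)
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

# 3x3 conv, stride 2, dilation 1 on a 224-px side: ceil(224/2) = 112 output positions
# require (112 - 1) * 2 + 3 - 224 = 1 extra pixel, split as 0 left and 1 right by pad_same
assert get_same_padding(224, 3, 2, 1) == 1
# stride 1 keeps the size, so a 3x3 kernel needs the usual 2 pixels total (1 per side)
assert get_same_padding(224, 3, 1, 1) == 2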
timm/models/layers/se.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from torch import nn as nn
+
+
+class SEModule(nn.Module):
+
+    def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
+        super(SEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        reduction_channels = max(channels // reduction, 8)
+        self.fc1 = nn.Conv2d(
+            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.act = act_layer(inplace=True)
+        self.fc2 = nn.Conv2d(
+            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
+
+    def forward(self, x):
+        x_se = self.avg_pool(x)
+        x_se = self.fc1(x_se)
+        x_se = self.act(x_se)
+        x_se = self.fc2(x_se)
+        return x * x_se.sigmoid()
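A quick hedged sketch of the squeeze-and-excitation bottleneck sizing in the new SEModule (import path as added in this commit; widths and input size are illustrative):

import torch
from timm.models.layers import SEModule

# channels // reduction is floored at 8, so narrow layers do not collapse the bottleneck
se = SEModule(channels=64, reduction=16)  # 64 // 16 = 4 -> clamped to 8 hidden channels
print(se.fc1.out_channels)                # 8

x = torch.randn(2, 64, 56, 56)
print(se(x).shape)                        # torch.Size([2, 64, 56, 56]), input rescaled per channel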
@@ -34,6 +34,8 @@ class TestTimePoolHead(nn.Module):

 def apply_test_time_pool(model, config, args):
     test_time_pool = False
+    if not hasattr(model, 'default_cfg') or not model.default_cfg:
+        return model, False
     if not args.no_test_pool and \
             config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
             config['input_size'][-2] > model.default_cfg['input_size'][-2]:
@@ -11,7 +11,7 @@ Hacked together by Ross Wightman
 from .efficientnet_builder import *
 from .registry import register_model
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, select_conv2d
+from .layers import SelectAdaptivePool2d, create_conv2d
 from .layers.activations import HardSwish, hard_sigmoid
 from .feature_hooks import FeatureHooks
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -82,7 +82,7 @@ class MobileNetV3(nn.Module):

         # Stem
         stem_size = round_channels(stem_size, channel_multiplier)
-        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
         self._in_chs = stem_size
@@ -97,7 +97,7 @@ class MobileNetV3(nn.Module):

         # Head + Pooling
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
+        self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
         self.act2 = act_layer(inplace=True)

         # Classifier
@@ -162,7 +162,7 @@ class MobileNetV3Features(nn.Module):

         # Stem
         stem_size = round_channels(stem_size, channel_multiplier)
-        self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
         self._in_chs = stem_size

@@ -8,10 +8,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .resnet import ResNet, SEModule
+from .resnet import ResNet
 from .registry import register_model
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d
+from .layers import SEModule
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 __all__ = []
@@ -53,8 +53,8 @@ class Bottle2neck(nn.Module):
     expansion = 4

     def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=26, scale=4, use_se=False,
-                 act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_):
+                 cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
+                 act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
         super(Bottle2neck, self).__init__()
         self.scale = scale
         self.is_first = stride > 1 or downsample is not None
@@ -82,7 +82,7 @@ class Bottle2neck(nn.Module):

         self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
-        self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.se = attn_layer(outplanes) if attn_layer is not None else None

         self.relu = act_layer(inplace=True)
         self.downsample = downsample

@@ -7,13 +7,12 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
 """
 import math

-import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from .registry import register_model
 from .helpers import load_pretrained
-from .layers import EcaModule, SelectAdaptivePool2d, DropBlock2d, DropPath
+from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


@@ -103,7 +102,8 @@ default_cfgs = {
     'ecaresnext26tn_32x4d': _cfg(
         url='',
         interpolation='bicubic'),
+    'ecaresnet18': _cfg(),
+    'ecaresnet50': _cfg(),
 }


@@ -112,32 +112,12 @@ def get_padding(kernel_size, stride, dilation=1):
     return padding


-class SEModule(nn.Module):
-
-    def __init__(self, channels, reduction_channels):
-        super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-
-    def forward(self, x):
-        x_se = self.avg_pool(x)
-        x_se = self.fc1(x_se)
-        x_se = self.relu(x_se)
-        x_se = self.fc2(x_se)
-        return x * x_se.sigmoid()
-
-
 class BasicBlock(nn.Module):
-    __constants__ = ['se', 'downsample']  # for pre 1.4 torchscript compat
     expansion = 1

-    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
+    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
                  reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 drop_block=None, drop_path=None):
+                 attn_layer=None, drop_block=None, drop_path=None):
         super(BasicBlock, self).__init__()

         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -155,7 +135,7 @@ class BasicBlock(nn.Module):
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
         self.bn2 = norm_layer(outplanes)

-        self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.se = create_attn(attn_layer, outplanes)

         self.act2 = act_layer(inplace=True)
         self.downsample = downsample
@@ -199,9 +179,9 @@ class Bottleneck(nn.Module):
     __constants__ = ['se', 'downsample']  # for pre 1.4 torchscript compat
     expansion = 4

-    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
+    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
                  reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 drop_block=None, drop_path=None):
+                 attn_layer=None, drop_block=None, drop_path=None):
         super(Bottleneck, self).__init__()

         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -220,7 +200,7 @@ class Bottleneck(nn.Module):
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)

-        self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.se = create_attn(attn_layer, outplanes)

         self.act3 = act_layer(inplace=True)
         self.downsample = downsample
@@ -266,6 +246,37 @@ class Bottleneck(nn.Module):
         return x


+def downsample_conv(
+        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+    norm_layer = norm_layer or nn.BatchNorm2d
+    kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
+    first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
+    p = get_padding(kernel_size, stride, first_dilation)
+
+    return nn.Sequential(*[
+        nn.Conv2d(
+            in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
+        norm_layer(out_channels)
+    ])
+
+
+def downsample_avg(
+        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+    norm_layer = norm_layer or nn.BatchNorm2d
+    avg_stride = stride if dilation == 1 else 1
+    if stride == 1 and dilation == 1:
+        pool = nn.Identity()
+    else:
+        avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+        pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+
+    return nn.Sequential(*[
+        pool,
+        nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
+        norm_layer(out_channels)
+    ])
+
+
 class ResNet(nn.Module):
     """ResNet / ResNeXt / SE-ResNeXt / SE-Net

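The effect of the downsample rewrite is easiest to see by instantiating the two helpers above directly; a minimal sketch, assuming they remain module-level functions in timm.models.resnet as added here (the channel counts and strides are illustrative, not taken from a specific model config):

from timm.models.resnet import downsample_avg, downsample_conv

# stride-2 shortcut with avg_down: a 2x2 AvgPool2d does the striding,
# followed by a stride-1 1x1 conv and BatchNorm
print(downsample_avg(64, 256, kernel_size=1, stride=2, dilation=1))

# same call in a dilated stage: avg_stride drops to 1 and AvgPool2dSame is used,
# so spatial size is preserved instead of being halved
print(downsample_avg(64, 256, kernel_size=1, stride=2, dilation=2))

# non-strided, non-dilated conv shortcut collapses to a plain 1x1 conv + BN
print(downsample_conv(64, 256, kernel_size=3, stride=1, dilation=1))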
@@ -307,8 +318,6 @@ class ResNet(nn.Module):
         Number of classification classes.
     in_chans : int, default 3
         Number of input (color) channels.
-    use_se : bool, default False
-        Enable Squeeze-Excitation module in blocks
     cardinality : int, default 1
         Number of convolution groups for 3x3 conv in Bottleneck.
     base_width : int, default 64
@@ -337,7 +346,7 @@ class ResNet(nn.Module):
     global_pool : str, default 'avg'
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
     """
-    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
+    def __init__(self, block, layers, num_classes=1000, in_chans=3,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
@@ -385,14 +394,14 @@ class ResNet(nn.Module):
             dilations[2:4] = [2, 4]
         else:
             assert output_stride == 32
-        llargs = list(zip(channels, layers, strides, dilations))
-        lkwargs = dict(
-            use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
+        layer_args = list(zip(channels, layers, strides, dilations))
+        layer_kwargs = dict(
+            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
             avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
-        self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
-        self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs)
-        self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs)
+        self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
+        self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
+        self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs)
+        self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs)

         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -411,31 +420,21 @@ class ResNet(nn.Module):
                 m.zero_init_last_bn()

     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    use_se=False, use_eca=False, avg_down=False, down_kernel_size=1, **kwargs):
-        norm_layer = kwargs.get('norm_layer')
+                    avg_down=False, down_kernel_size=1, **kwargs):
         downsample = None
-        down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample_padding = get_padding(down_kernel_size, stride)
-            downsample_layers = []
-            conv_stride = stride
-            if avg_down:
-                avg_stride = stride if dilation == 1 else 1
-                conv_stride = 1
-                downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)]
-            downsample_layers += [
-                nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size,
-                          stride=conv_stride, padding=downsample_padding, bias=False),
-                norm_layer(planes * block.expansion)]
-            downsample = nn.Sequential(*downsample_layers)
-
         first_dilation = 1 if dilation in (1, 2) else 2
-        bkwargs = dict(
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample_args = dict(
+                in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
+                stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
+            downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
+
+        block_kwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, use_se=use_se, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)]
+            dilation=dilation, **kwargs)
+        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
         self.inplanes = planes * block.expansion
-        layers += [block(self.inplanes, planes, **bkwargs) for _ in range(1, blocks)]
+        layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]

         return nn.Sequential(*layers)

@@ -936,9 +935,8 @@ def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     """
     default_cfg = default_cfgs['seresnext26d_32x4d']
     model = ResNet(
-        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
-        stem_width=32, stem_type='deep', avg_down=True, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -954,8 +952,8 @@ def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     default_cfg = default_cfgs['seresnext26t_32x4d']
     model = ResNet(
         Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
-        stem_width=32, stem_type='deep_tiered', avg_down=True, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        stem_width=32, stem_type='deep_tiered', avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -971,25 +969,55 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
     default_cfg = default_cfgs['seresnext26tn_32x4d']
     model = ResNet(
         Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
-        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model


 @register_model
 def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a eca-ResNeXt-26-TN model.
+    """Constructs an ECA-ResNeXt-26-TN model.
     This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
     in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
     this model replaces SE module with the ECA module
     """
     default_cfg = default_cfgs['ecaresnext26tn_32x4d']
+    block_args = dict(attn_layer='eca')
     model = ResNet(
         Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
-        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Constructs an ECA-ResNet-18 model.
+    """
+    default_cfg = default_cfgs['ecaresnet18']
+    block_args = dict(attn_layer='eca')
+    model = ResNet(
+        BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs an ECA-ResNet-50 model.
+    """
+    default_cfg = default_cfgs['ecaresnet50']
+    block_args = dict(attn_layer='eca')
+    model = ResNet(
+        Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)

@@ -4,8 +4,8 @@ from torch import nn as nn

 from .registry import register_model
 from .helpers import load_pretrained
-from .layers import SelectiveKernelConv, ConvBnAct
-from .resnet import ResNet, SEModule
+from .layers import SelectiveKernelConv, ConvBnAct, create_attn
+from .resnet import ResNet
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


@@ -33,8 +33,8 @@ class SelectiveKernelBasic(nn.Module):
     expansion = 1

     def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-                 use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
-                 drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+                 sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
+                 drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
         super(SelectiveKernelBasic, self).__init__()

         sk_kwargs = sk_kwargs or {}
@@ -42,7 +42,7 @@ class SelectiveKernelBasic(nn.Module):
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
         assert base_width == 64, 'BasicBlock doest not support changing base width'
         first_planes = planes // reduce_first
-        out_planes = planes * self.expansion
+        outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation

         _selective_first = True  # FIXME temporary, for experiments
@@ -51,14 +51,14 @@ class SelectiveKernelBasic(nn.Module):
                 inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
             conv_kwargs['act_layer'] = None
             self.conv2 = ConvBnAct(
-                first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs)
+                first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
         else:
             self.conv1 = ConvBnAct(
                 inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
             conv_kwargs['act_layer'] = None
             self.conv2 = SelectiveKernelConv(
-                first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs)
-        self.se = SEModule(out_planes, planes // 4) if use_se else None
+                first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs)
+        self.se = create_attn(attn_layer, outplanes)
         self.act = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -88,17 +88,15 @@ class SelectiveKernelBottleneck(nn.Module):
     expansion = 4

     def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
-                 reduce_first=1, dilation=1, first_dilation=None,
-                 drop_block=None, drop_path=None,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+                 cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
+                 drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
         super(SelectiveKernelBottleneck, self).__init__()

         sk_kwargs = sk_kwargs or {}
         conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
         first_planes = width // reduce_first
-        out_planes = planes * self.expansion
+        outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation

         self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
@@ -106,8 +104,8 @@ class SelectiveKernelBottleneck(nn.Module):
             first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
             **conv_kwargs, **sk_kwargs)
         conv_kwargs['act_layer'] = None
-        self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs)
-        self.se = SEModule(out_planes, planes // 4) if use_se else None
+        self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
+        self.se = create_attn(attn_layer, outplanes)
         self.act = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride