mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add ConvBnAct layer to parallel integrated SelectKernelConv, add support for DropPath and DropBlock to ResNet base and SK blocks
This commit is contained in:
parent
cefc9b7761
commit
9f11b4e8a2
@ -271,11 +271,36 @@ def _kernel_valid(k):
|
|||||||
assert k >= 3 and k % 2
|
assert k >= 3 and k % 2
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBnAct(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
|
||||||
|
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
|
super(ConvBnAct, self).__init__()
|
||||||
|
padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
|
||||||
|
padding=padding, dilation=dilation, groups=groups, bias=False)
|
||||||
|
self.bn = norm_layer(out_channels)
|
||||||
|
self.drop_block = drop_block
|
||||||
|
if act_layer is not None:
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
else:
|
||||||
|
self.act = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
if self.drop_block is not None:
|
||||||
|
x = self.drop_block(x)
|
||||||
|
if self.act is not None:
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SelectiveKernelConv(nn.Module):
|
class SelectiveKernelConv(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
|
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
|
||||||
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
|
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
|
||||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
super(SelectiveKernelConv, self).__init__()
|
super(SelectiveKernelConv, self).__init__()
|
||||||
kernel_size = kernel_size or [3, 5]
|
kernel_size = kernel_size or [3, 5]
|
||||||
_kernel_valid(kernel_size)
|
_kernel_valid(kernel_size)
|
||||||
@ -297,19 +322,15 @@ class SelectiveKernelConv(nn.Module):
|
|||||||
out_channels = out_channels // num_paths
|
out_channels = out_channels // num_paths
|
||||||
groups = min(out_channels, groups)
|
groups = min(out_channels, groups)
|
||||||
|
|
||||||
self.paths = nn.ModuleList()
|
conv_kwargs = dict(
|
||||||
for k, d in zip(kernel_size, dilation):
|
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
p = _get_padding(k, stride, d)
|
self.paths = nn.ModuleList([
|
||||||
self.paths.append(nn.Sequential(OrderedDict([
|
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
|
||||||
('conv', nn.Conv2d(
|
for k, d in zip(kernel_size, dilation)])
|
||||||
in_channels, out_channels, kernel_size=k, stride=stride, padding=p,
|
|
||||||
dilation=d, groups=groups, bias=False)),
|
|
||||||
('bn', norm_layer(out_channels)),
|
|
||||||
('act', act_layer(inplace=True))
|
|
||||||
])))
|
|
||||||
|
|
||||||
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
||||||
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
|
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
|
||||||
|
self.drop_block = drop_block
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.split_input:
|
if self.split_input:
|
||||||
|
@ -14,6 +14,7 @@ import torch.nn.functional as F
|
|||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||||
|
from .nn_ops import DropBlock2d, DropPath
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
|
||||||
@ -132,7 +133,8 @@ class BasicBlock(nn.Module):
|
|||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
|
||||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||||
|
drop_block=None, drop_path=None):
|
||||||
super(BasicBlock, self).__init__()
|
super(BasicBlock, self).__init__()
|
||||||
|
|
||||||
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||||
@ -181,7 +183,8 @@ class Bottleneck(nn.Module):
|
|||||||
expansion = 4
|
expansion = 4
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
|
||||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||||
|
drop_block=None, drop_path=None):
|
||||||
super(Bottleneck, self).__init__()
|
super(Bottleneck, self).__init__()
|
||||||
|
|
||||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||||
@ -305,8 +308,8 @@ class ResNet(nn.Module):
|
|||||||
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
|
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
|
||||||
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
||||||
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
||||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
|
||||||
zero_init_last_bn=True, block_args=None):
|
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
|
||||||
block_args = block_args or dict()
|
block_args = block_args or dict()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
deep_stem = 'deep' in stem_type
|
deep_stem = 'deep' in stem_type
|
||||||
@ -338,6 +341,9 @@ class ResNet(nn.Module):
|
|||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
# Feature Blocks
|
# Feature Blocks
|
||||||
|
dp = DropPath(drop_path_rate) if drop_block_rate else None
|
||||||
|
db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
|
||||||
|
db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
|
||||||
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
|
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
|
||||||
if output_stride == 16:
|
if output_stride == 16:
|
||||||
strides[3] = 1
|
strides[3] = 1
|
||||||
@ -350,11 +356,11 @@ class ResNet(nn.Module):
|
|||||||
llargs = list(zip(channels, layers, strides, dilations))
|
llargs = list(zip(channels, layers, strides, dilations))
|
||||||
lkwargs = dict(
|
lkwargs = dict(
|
||||||
use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
|
use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
|
||||||
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
|
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
|
||||||
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
|
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
|
||||||
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
|
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
|
||||||
self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
|
self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs)
|
||||||
self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
|
self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs)
|
||||||
|
|
||||||
# Head (Pooling and Classifier)
|
# Head (Pooling and Classifier)
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
@ -4,7 +4,7 @@ from torch import nn as nn
|
|||||||
|
|
||||||
from timm.models.registry import register_model
|
from timm.models.registry import register_model
|
||||||
from timm.models.helpers import load_pretrained
|
from timm.models.helpers import load_pretrained
|
||||||
from timm.models.conv2d_layers import SelectiveKernelConv
|
from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct
|
||||||
from timm.models.resnet import ResNet, SEModule
|
from timm.models.resnet import ResNet, SEModule
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
@ -29,61 +29,53 @@ default_cfgs = {
|
|||||||
class SelectiveKernelBasic(nn.Module):
|
class SelectiveKernelBasic(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
||||||
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
|
||||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
super(SelectiveKernelBasic, self).__init__()
|
super(SelectiveKernelBasic, self).__init__()
|
||||||
|
|
||||||
sk_kwargs = sk_kwargs or {}
|
sk_kwargs = sk_kwargs or {}
|
||||||
|
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||||
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
||||||
first_planes = planes // reduce_first
|
first_planes = planes // reduce_first
|
||||||
outplanes = planes * self.expansion
|
out_planes = planes * self.expansion
|
||||||
first_dilation = first_dilation or dilation
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
_selective_first = True # FIXME temporary, for experiments
|
_selective_first = True # FIXME temporary, for experiments
|
||||||
if _selective_first:
|
if _selective_first:
|
||||||
self.conv1 = SelectiveKernelConv(
|
self.conv1 = SelectiveKernelConv(
|
||||||
inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs)
|
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
|
||||||
else:
|
conv_kwargs['act_layer'] = None
|
||||||
self.conv1 = nn.Conv2d(
|
self.conv2 = ConvBnAct(
|
||||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs)
|
||||||
dilation=first_dilation, bias=False)
|
|
||||||
self.bn1 = norm_layer(first_planes)
|
|
||||||
self.act1 = act_layer(inplace=True)
|
|
||||||
if _selective_first:
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
first_planes, outplanes, kernel_size=3, padding=dilation,
|
|
||||||
dilation=dilation, bias=False)
|
|
||||||
else:
|
else:
|
||||||
|
self.conv1 = ConvBnAct(
|
||||||
|
inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
|
||||||
|
conv_kwargs['act_layer'] = None
|
||||||
self.conv2 = SelectiveKernelConv(
|
self.conv2 = SelectiveKernelConv(
|
||||||
first_planes, outplanes, dilation=dilation, **sk_kwargs)
|
first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs)
|
||||||
self.bn2 = norm_layer(outplanes)
|
self.se = SEModule(out_planes, planes // 4) if use_se else None
|
||||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
self.act = act_layer(inplace=True)
|
||||||
self.act2 = act_layer(inplace=True)
|
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
|
self.drop_block = drop_block
|
||||||
|
self.drop_path = drop_path
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
out = self.conv1(x)
|
x = self.conv2(x)
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.act1(out)
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
|
|
||||||
if self.se is not None:
|
if self.se is not None:
|
||||||
out = self.se(out)
|
x = self.se(x)
|
||||||
|
if self.drop_path is not None:
|
||||||
|
x = self.drop_path(x)
|
||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
residual = self.downsample(x)
|
residual = self.downsample(residual)
|
||||||
|
x += residual
|
||||||
out += residual
|
x = self.act(x)
|
||||||
out = self.act2(out)
|
return x
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class SelectiveKernelBottleneck(nn.Module):
|
class SelectiveKernelBottleneck(nn.Module):
|
||||||
@ -91,54 +83,46 @@ class SelectiveKernelBottleneck(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||||
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
||||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
reduce_first=1, dilation=1, first_dilation=None,
|
||||||
|
drop_block=None, drop_path=None,
|
||||||
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
super(SelectiveKernelBottleneck, self).__init__()
|
super(SelectiveKernelBottleneck, self).__init__()
|
||||||
|
|
||||||
sk_kwargs = sk_kwargs or {}
|
sk_kwargs = sk_kwargs or {}
|
||||||
|
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||||
first_planes = width // reduce_first
|
first_planes = width // reduce_first
|
||||||
outplanes = planes * self.expansion
|
out_planes = planes * self.expansion
|
||||||
first_dilation = first_dilation or dilation
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
|
||||||
self.bn1 = norm_layer(first_planes)
|
|
||||||
self.act1 = act_layer(inplace=True)
|
|
||||||
self.conv2 = SelectiveKernelConv(
|
self.conv2 = SelectiveKernelConv(
|
||||||
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs)
|
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
|
||||||
self.bn2 = norm_layer(width)
|
**conv_kwargs, **sk_kwargs)
|
||||||
self.act2 = act_layer(inplace=True)
|
conv_kwargs['act_layer'] = None
|
||||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs)
|
||||||
self.bn3 = norm_layer(outplanes)
|
self.se = SEModule(out_planes, planes // 4) if use_se else None
|
||||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
self.act = act_layer(inplace=True)
|
||||||
self.act3 = act_layer(inplace=True)
|
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
|
self.drop_block = drop_block
|
||||||
|
self.drop_path = drop_path
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
out = self.conv1(x)
|
x = self.conv2(x)
|
||||||
out = self.bn1(out)
|
x = self.conv3(x)
|
||||||
out = self.act1(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
out = self.act2(out)
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
|
||||||
out = self.bn3(out)
|
|
||||||
|
|
||||||
if self.se is not None:
|
if self.se is not None:
|
||||||
out = self.se(out)
|
x = self.se(x)
|
||||||
|
if self.drop_path is not None:
|
||||||
|
x = self.drop_path(x)
|
||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
residual = self.downsample(x)
|
residual = self.downsample(residual)
|
||||||
|
x += residual
|
||||||
out += residual
|
x = self.act(x)
|
||||||
out = self.act3(out)
|
return x
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user