Mirror of https://github.com/huggingface/pytorch-image-models.git

Merge branch 'select_kernel' into attention

Commit f54612f648
@@ -16,6 +1,7 @@ from .gluon_xception import *
from .res2net import *
from .dla import *
from .hrnet import *
from .sknet import *

from .registry import *
from .factory import create_model
@@ -1,3 +1,5 @@
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -100,14 +102,11 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)


class MixedConv2d(nn.Module):
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    NOTE: This does not currently work with torch.jit.script
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super(MixedConv2d, self).__init__()

@@ -131,7 +130,7 @@ class MixedConv2d(nn.Module):

    def forward(self, x):
        x_split = torch.split(x, self.splits, 1)
        x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
        x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
        x = torch.cat(x_out, 1)
        return x
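The forward pass above splits the input along dim 1 according to self.splits and runs each chunk through its own conv before concatenating the results. Below is a minimal usage sketch (not part of the commit), assuming MixedConv2d is importable from timm.models.conv2d_layers as in the rest of this changeset:

```python
import torch
from timm.models.conv2d_layers import MixedConv2d

# 32 channels spread across three kernel sizes; depthwise so each split
# is convolved with groups equal to its own channel count
conv = MixedConv2d(32, 32, kernel_size=[3, 5, 7], depthwise=True)
x = torch.randn(2, 32, 56, 56)
print(conv(x).shape)  # torch.Size([2, 32, 56, 56]) - per-kernel outputs concatenated on dim 1
```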
@@ -240,6 +239,110 @@ class CondConv2d(nn.Module):
        return out


class SelectiveKernelAttn(nn.Module):
    def __init__(self, channels, num_paths=2, attn_channels=32,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(SelectiveKernelAttn, self).__init__()
        self.num_paths = num_paths
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
        self.bn = norm_layer(attn_channels)
        self.act = act_layer(inplace=True)
        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)

    def forward(self, x):
        assert x.shape[1] == self.num_paths
        x = torch.sum(x, dim=1)
        x = self.pool(x)
        x = self.fc_reduce(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.fc_select(x)
        B, C, H, W = x.shape
        x = x.view(B, self.num_paths, C // self.num_paths, H, W)
        x = torch.softmax(x, dim=1)
        return x
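For orientation (not part of the diff): a shape walkthrough of SelectiveKernelAttn as defined above, assuming the same module path used elsewhere in this commit. The input stacks the per-path features on dim 1, and the output is a softmax over paths that broadcasts back onto them.

```python
import torch
from timm.models.conv2d_layers import SelectiveKernelAttn

attn = SelectiveKernelAttn(channels=64, num_paths=2, attn_channels=32)
x = torch.randn(8, 2, 64, 28, 28)    # (batch, num_paths, channels, H, W)
w = attn(x)
print(w.shape)                       # torch.Size([8, 2, 64, 1, 1])
print(w.sum(dim=1).allclose(torch.ones_like(w.sum(dim=1))))  # softmax over paths sums to 1
```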
def _kernel_valid(k):
    if isinstance(k, (list, tuple)):
        for ki in k:
            return _kernel_valid(ki)
    assert k >= 3 and k % 2


class ConvBnAct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(ConvBnAct, self).__init__()
        padding = _get_padding(kernel_size, stride, dilation)  # assuming PyTorch style padding for this block
        self.conv = nn.Conv2d(
            in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=False)
        self.bn = norm_layer(out_channels)
        self.drop_block = drop_block
        if act_layer is not None:
            self.act = act_layer(inplace=True)
        else:
            self.act = None

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        if self.act is not None:
            x = self.act(x)
        return x
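ConvBnAct above is just a conv + norm + optional activation unit with symmetric padding. A quick sketch of its behavior (not part of the commit, module path assumed as above):

```python
import torch
from timm.models.conv2d_layers import ConvBnAct

block = ConvBnAct(16, 32, kernel_size=3, stride=2)   # conv -> BN -> ReLU
x = torch.randn(4, 16, 32, 32)
print(block(x).shape)   # torch.Size([4, 32, 16, 16]) - stride 2 halves the spatial dims
```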
class SelectiveKernelConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
                 attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(SelectiveKernelConv, self).__init__()
        kernel_size = kernel_size or [3, 5]
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [dilation] * len(kernel_size)
        self.num_paths = len(kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.split_input = split_input
        if self.split_input:
            assert in_channels % self.num_paths == 0
            in_channels = in_channels // self.num_paths
        groups = min(out_channels, groups)

        conv_kwargs = dict(
            stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
        self.paths = nn.ModuleList([
            ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
            for k, d in zip(kernel_size, dilation)])

        attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
        self.drop_block = drop_block

    def forward(self, x):
        if self.split_input:
            x_split = torch.split(x, self.in_channels // self.num_paths, 1)
            x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
        else:
            x_paths = [op(x) for op in self.paths]
        x = torch.stack(x_paths, dim=1)
        x_attn = self.attn(x)
        x = x * x_attn
        x = torch.sum(x, dim=1)
        return x
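Not part of the commit: a usage sketch of SelectiveKernelConv with its defaults (kernel_size=[3, 5] collapses to two 3x3 paths with dilations 1 and 2 when keep_3x3=True); the attention weights the per-path outputs and sums them back to a single tensor. Module path assumed as above.

```python
import torch
from timm.models.conv2d_layers import SelectiveKernelConv

sk = SelectiveKernelConv(64, 64)                            # two 3x3 paths, dilation 1 and 2
x = torch.randn(2, 64, 56, 56)
print(sk(x).shape)                                          # torch.Size([2, 64, 56, 56])

sk_split = SelectiveKernelConv(64, 64, split_input=True)    # each path sees half the input channels
print(sk_split(x).shape)                                    # torch.Size([2, 64, 56, 56])
```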
# helper method
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    assert 'groups' not in kwargs  # only use 'depthwise' bool arg

@@ -256,5 +359,3 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    else:
        m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
    return m
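The select_conv2d helper routes to the right conv flavor based on its arguments; only its tail is visible in this hunk, so the sketch below is hedged to the signature shown (a scalar kernel_size with depthwise=True is assumed to give a grouped conv, and a list of kernel sizes is assumed to go through MixedConv2d):

```python
import torch
from timm.models.conv2d_layers import select_conv2d

dw = select_conv2d(32, 32, kernel_size=3, depthwise=True)   # plain depthwise conv
mixed = select_conv2d(32, 64, kernel_size=[3, 5, 7])        # list of kernels -> MixedConv2d
x = torch.randn(2, 32, 28, 28)
print(dw(x).shape, mixed(x).shape)                          # (2, 32, 28, 28) (2, 64, 28, 28)
```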
timm/models/nn_ops.py (new file, 77 lines)

@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
    _, _, height, width = x.shape
    total_size = width * height
    clipped_block_size = min(block_size, min(width, height))
    # seed_drop_rate, the gamma parameter
    seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
        (width - block_size + 1) *
        (height - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()

    uniform_noise = torch.rand_like(x, dtype=torch.float32)
    block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if drop_with_noise:
        normal_noise = torch.randn_like(x)
        x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
        x = x * block_mask * normalize_scale
    return x


class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """
    def __init__(self,
                 drop_prob=0.1,
                 block_size=7,
                 gamma_scale=1.0,
                 with_noise=False):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise)
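Not part of the new file: a short sketch of how DropBlock2d behaves with the defaults above. It is a no-op outside training mode; in training it zeroes contiguous spatial blocks and rescales what remains.

```python
import torch
from timm.models.nn_ops import DropBlock2d

db = DropBlock2d(drop_prob=0.1, block_size=7)
x = torch.randn(4, 64, 28, 28)

db.train()
y = db(x)                      # square blocks of activations zeroed, remainder rescaled

db.eval()
print(torch.equal(db(x), x))   # True - identity at inference time
```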
def drop_path(x, drop_prob=0.):
    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks)."""
    keep_prob = 1 - drop_prob
    random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.ModuleDict):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        return drop_path(x, self.drop_prob)
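Likewise not part of the file: drop_path / DropPath act per sample, so each item in the batch is either zeroed entirely or scaled by 1 / keep_prob, keeping the expectation unchanged. A small sketch assuming the module path of the new file:

```python
import torch
from timm.models.nn_ops import DropPath

dp = DropPath(drop_prob=0.25)
x = torch.ones(8, 16, 1, 1)

dp.train()
y = dp(x)
print(y[:, 0, 0, 0])           # each sample is either 0.0 or 1 / 0.75 ~= 1.333

dp.eval()
print(torch.equal(dp(x), x))   # True - no-op at inference
```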
@@ -54,14 +54,15 @@ class Bottle2neck(nn.Module):

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, base_width=26, scale=4, use_se=False,
                 act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
                 act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_):
        super(Bottle2neck, self).__init__()
        self.scale = scale
        self.is_first = stride > 1 or downsample is not None
        self.num_scales = max(1, scale - 1)
        width = int(math.floor(planes * (base_width / 64.0))) * cardinality
        outplanes = planes * self.expansion
        self.width = width
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = norm_layer(width * scale)
@@ -70,8 +71,8 @@ class Bottle2neck(nn.Module):
        bns = []
        for i in range(self.num_scales):
            convs.append(nn.Conv2d(
                width, width, kernel_size=3, stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False))
                width, width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, bias=False))
            bns.append(norm_layer(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
@@ -86,6 +87,9 @@ class Bottle2neck(nn.Module):
        self.relu = act_layer(inplace=True)
        self.downsample = downsample

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        residual = x
@ -15,6 +15,7 @@ from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .layers import EcaModule
from .nn_ops import DropBlock2d, DropPath
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


@@ -108,7 +109,7 @@ default_cfgs = {
}


def _get_padding(kernel_size, stride, dilation=1):
def get_padding(kernel_size, stride, dilation=1):
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding
@@ -136,115 +137,135 @@ class BasicBlock(nn.Module):
    __constants__ = ['se', 'downsample']  # for pre 1.4 torchscript compat
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, base_width=64, use_se=False, use_eca = False,
                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
                 drop_block=None, drop_path=None):
        super(BasicBlock, self).__init__()

        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock doest not support changing base width'
        first_planes = planes // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = nn.Conv2d(
            inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
            dilation=dilation, bias=False)
            inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
            dilation=first_dilation, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = act_layer(inplace=True)
        self.conv2 = nn.Conv2d(
            first_planes, outplanes, kernel_size=3, padding=previous_dilation,
            dilation=previous_dilation, bias=False)
            first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = norm_layer(outplanes)

        self.se = SEModule(outplanes, planes // 4) if use_se else None
        self.eca = EcaModule(outplanes) if use_eca else None

        self.act2 = act_layer(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.se is not None:
            out = self.se(out)
        if self.eca is not None:
            out = self.eca(out)
            x = self.se(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.downsample is not None:
            residual = self.downsample(x)
            residual = self.downsample(residual)
        x += residual
        x = self.act2(x)

        out += residual
        out = self.act2(out)

        return out
        return x
class Bottleneck(nn.Module):
    __constants__ = ['se', 'downsample']  # for pre 1.4 torchscript compat
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, base_width=64, use_se=False, use_eca=False,
                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
                 drop_block=None, drop_path=None):
        super(Bottleneck, self).__init__()

        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = act_layer(inplace=True)
        self.conv2 = nn.Conv2d(
            first_planes, width, kernel_size=3, stride=stride,
            padding=dilation, dilation=dilation, groups=cardinality, bias=False)
            padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
        self.bn2 = norm_layer(width)
        self.act2 = act_layer(inplace=True)
        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)

        self.se = SEModule(outplanes, planes // 4) if use_se else None
        self.eca = EcaModule(outplanes) if use_eca else None


        self.act3 = act_layer(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act2(out)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act2(x)

        out = self.conv3(out)
        out = self.bn3(out)
        x = self.conv3(x)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.se is not None:
            out = self.se(out)
        if self.eca is not None:
            out = self.eca(out)
            x = self.se(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.downsample is not None:
            residual = self.downsample(x)
            residual = self.downsample(residual)
        x += residual
        x = self.act3(x)

        out += residual
        out = self.act3(out)

        return out
        return x
class ResNet(nn.Module):
@@ -290,8 +311,6 @@ class ResNet(nn.Module):
        Number of input (color) channels.
    use_se : bool, default False
        Enable Squeeze-Excitation module in blocks
    use_eca : bool, default False
        Enable ECA module in blocks
    cardinality : int, default 1
        Number of convolution groups for 3x3 conv in Bottleneck.
    base_width : int, default 64
@@ -323,8 +342,8 @@ class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
                 cardinality=1, base_width=64, stem_width=64, stem_type='',
                 block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
                 zero_init_last_bn=True, block_args=None):
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
                 drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
        block_args = block_args or dict()
        self.num_classes = num_classes
        deep_stem = 'deep' in stem_type
@@ -356,6 +375,9 @@ class ResNet(nn.Module):
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        dp = DropPath(drop_path_rate) if drop_block_rate else None
        db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
        db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
        channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
        if output_stride == 16:
            strides[3] = 1
@@ -367,29 +389,28 @@ class ResNet(nn.Module):
            assert output_stride == 32
        llargs = list(zip(channels, layers, strides, dilations))
        lkwargs = dict(
            use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
            avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
            use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
        self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
        self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
        self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
        self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
        self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs)
        self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs)

        # Head (Pooling and Classifier)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_features = 512 * block.expansion
        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

        last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2'
        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                if zero_init_last_bn and 'layer' in n and last_bn_name in n:
                    # Initialize weight/gamma of last BN in each residual block to zero
                    nn.init.constant_(m.weight, 0.)
                else:
                    nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
                    use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs):
@@ -397,7 +418,7 @@ class ResNet(nn.Module):
        downsample = None
        down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample_padding = _get_padding(down_kernel_size, stride)
            downsample_padding = get_padding(down_kernel_size, stride)
            downsample_layers = []
            conv_stride = stride
            if avg_down:
@@ -413,13 +434,10 @@ class ResNet(nn.Module):
        first_dilation = 1 if dilation in (1, 2) else 2
        bkwargs = dict(
            cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
            use_se=use_se, use_eca=use_eca, **kwargs)
        layers = [block(
            self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
            dilation=dilation, use_se=use_se, **kwargs)
        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(
                self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs))
        layers += [block(self.inplanes, planes, **bkwargs) for _ in range(1, blocks)]

        return nn.Sequential(*layers)

@@ -447,8 +465,8 @@ class ResNet(nn.Module):
    def forward(self, x):
        x = self.forward_features(x)
        x = self.global_pool(x).flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        return x
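Not part of the diff: a sketch of exercising the new ResNet regularization arguments introduced in this file (drop_path_rate and drop_block_rate; note that in this hunk the DropPath instance is also gated on drop_block_rate). Class names are the ones defined in timm.models.resnet; treat this as an assumption-laden usage sketch, not a verified snippet from the repo.

```python
import torch
from timm.models.resnet import ResNet, Bottleneck

# ResNet-50 layout with DropBlock on layers 3/4 and per-sample DropPath in the blocks
model = ResNet(Bottleneck, [3, 4, 6, 3], drop_rate=0.1, drop_path_rate=0.05, drop_block_rate=0.1)
model.train()
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)   # torch.Size([2, 1000])
```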
timm/models/sknet.py (new file, 242 lines)

@@ -0,0 +1,242 @@
import math

from torch import nn as nn

from timm.models.registry import register_model
from timm.models.helpers import load_pretrained
from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct
from timm.models.resnet import ResNet, SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = {
    'skresnet18': _cfg(url=''),
    'skresnet26d': _cfg(),
    'skresnet50': _cfg(),
    'skresnet50d': _cfg(),
    'skresnext50_32x4d': _cfg(),
}
class SelectiveKernelBasic(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
                 use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
                 drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(SelectiveKernelBasic, self).__init__()

        sk_kwargs = sk_kwargs or {}
        conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock doest not support changing base width'
        first_planes = planes // reduce_first
        out_planes = planes * self.expansion
        first_dilation = first_dilation or dilation

        _selective_first = True  # FIXME temporary, for experiments
        if _selective_first:
            self.conv1 = SelectiveKernelConv(
                inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
            conv_kwargs['act_layer'] = None
            self.conv2 = ConvBnAct(
                first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs)
        else:
            self.conv1 = ConvBnAct(
                inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
            conv_kwargs['act_layer'] = None
            self.conv2 = SelectiveKernelConv(
                first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs)
        self.se = SEModule(out_planes, planes // 4) if use_se else None
        self.act = act_layer(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        if self.se is not None:
            x = self.se(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        if self.downsample is not None:
            residual = self.downsample(residual)
        x += residual
        x = self.act(x)
        return x
class SelectiveKernelBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
                 reduce_first=1, dilation=1, first_dilation=None,
                 drop_block=None, drop_path=None,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(SelectiveKernelBottleneck, self).__init__()

        sk_kwargs = sk_kwargs or {}
        conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        out_planes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
        self.conv2 = SelectiveKernelConv(
            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
            **conv_kwargs, **sk_kwargs)
        conv_kwargs['act_layer'] = None
        self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs)
        self.se = SEModule(out_planes, planes // 4) if use_se else None
        self.act = act_layer(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.conv3.bn.weight)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.se is not None:
            x = self.se(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        if self.downsample is not None:
            residual = self.downsample(residual)
        x += residual
        x = self.act(x)
        return x
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-18 model.
    """
    default_cfg = default_cfgs['skresnet18']
    sk_kwargs = dict(
        min_attn_channels=16,
    )
    model = ResNet(
        SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-18 model.
    """
    default_cfg = default_cfgs['skresnet18']
    sk_kwargs = dict(
        min_attn_channels=16,
        attn_reduction=8,
        split_input=True
    )
    model = ResNet(
        SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-26 model.
    """
    default_cfg = default_cfgs['skresnet26d']
    sk_kwargs = dict(
        keep_3x3=False,
    )
    model = ResNet(
        SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False,
        **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a Select Kernel ResNet-50 model.
    Based on config in "Compounding the Performance Improvements of Assembled Techniques in a
    Convolutional Neural Network"
    """
    sk_kwargs = dict(
        attn_reduction=2,
    )
    default_cfg = default_cfgs['skresnet50']
    model = ResNet(
        SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a Select Kernel ResNet-50-D model.
    Based on config in "Compounding the Performance Improvements of Assembled Techniques in a
    Convolutional Neural Network"
    """
    sk_kwargs = dict(
        attn_reduction=2,
    )
    default_cfg = default_cfgs['skresnet50d']
    model = ResNet(
        SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last_bn=False, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
    the SKNet50 model in the Select Kernel Paper
    """
    default_cfg = default_cfgs['skresnext50_32x4d']
    model = ResNet(
        SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
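Not part of the new file: the functions above are registered with timm's model registry, so they can be built through the factory; a minimal sketch assuming the __init__.py change earlier in this commit is in place.

```python
import torch
from timm.models import create_model

model = create_model('skresnet18', pretrained=False, num_classes=10)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)   # torch.Size([1, 10])
```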