mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Move SelectKernelConv to conv2d_layers and more
* always apply attention in SelectKernelConv, leave MixedConv for no attention alternative * make MixedConv torchscript compatible * refactor first/previous dilation name to make more sense in ResNet* networks
This commit is contained in:
parent
9abe610931
commit
cefc9b7761
@ -1,3 +1,5 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -100,14 +102,11 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
|||||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class MixedConv2d(nn.Module):
|
class MixedConv2d(nn.ModuleDict):
|
||||||
""" Mixed Grouped Convolution
|
""" Mixed Grouped Convolution
|
||||||
Based on MDConv and GroupedConv in MixNet impl:
|
Based on MDConv and GroupedConv in MixNet impl:
|
||||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||||
|
|
||||||
NOTE: This does not currently work with torch.jit.script
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||||
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
||||||
super(MixedConv2d, self).__init__()
|
super(MixedConv2d, self).__init__()
|
||||||
@ -131,7 +130,7 @@ class MixedConv2d(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_split = torch.split(x, self.splits, 1)
|
x_split = torch.split(x, self.splits, 1)
|
||||||
x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
|
x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
|
||||||
x = torch.cat(x_out, 1)
|
x = torch.cat(x_out, 1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -240,6 +239,97 @@ class CondConv2d(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SelectiveKernelAttn(nn.Module):
|
||||||
|
def __init__(self, channels, num_paths=2, attn_channels=32,
|
||||||
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
|
super(SelectiveKernelAttn, self).__init__()
|
||||||
|
self.num_paths = num_paths
|
||||||
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
|
||||||
|
self.bn = norm_layer(attn_channels)
|
||||||
|
self.act = act_layer(inplace=True)
|
||||||
|
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert x.shape[1] == self.num_paths
|
||||||
|
x = torch.sum(x, dim=1)
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.fc_reduce(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.fc_select(x)
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
|
||||||
|
x = torch.softmax(x, dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _kernel_valid(k):
|
||||||
|
if isinstance(k, (list, tuple)):
|
||||||
|
for ki in k:
|
||||||
|
return _kernel_valid(ki)
|
||||||
|
assert k >= 3 and k % 2
|
||||||
|
|
||||||
|
|
||||||
|
class SelectiveKernelConv(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
|
||||||
|
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
|
||||||
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
|
super(SelectiveKernelConv, self).__init__()
|
||||||
|
kernel_size = kernel_size or [3, 5]
|
||||||
|
_kernel_valid(kernel_size)
|
||||||
|
if not isinstance(kernel_size, list):
|
||||||
|
kernel_size = [kernel_size] * 2
|
||||||
|
if keep_3x3:
|
||||||
|
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
|
||||||
|
kernel_size = [3] * len(kernel_size)
|
||||||
|
else:
|
||||||
|
dilation = [dilation] * len(kernel_size)
|
||||||
|
num_paths = len(kernel_size)
|
||||||
|
self.num_paths = num_paths
|
||||||
|
self.split_input = split_input
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
if split_input:
|
||||||
|
assert in_channels % num_paths == 0 and out_channels % num_paths == 0
|
||||||
|
in_channels = in_channels // num_paths
|
||||||
|
out_channels = out_channels // num_paths
|
||||||
|
groups = min(out_channels, groups)
|
||||||
|
|
||||||
|
self.paths = nn.ModuleList()
|
||||||
|
for k, d in zip(kernel_size, dilation):
|
||||||
|
p = _get_padding(k, stride, d)
|
||||||
|
self.paths.append(nn.Sequential(OrderedDict([
|
||||||
|
('conv', nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=k, stride=stride, padding=p,
|
||||||
|
dilation=d, groups=groups, bias=False)),
|
||||||
|
('bn', norm_layer(out_channels)),
|
||||||
|
('act', act_layer(inplace=True))
|
||||||
|
])))
|
||||||
|
|
||||||
|
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
||||||
|
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.split_input:
|
||||||
|
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
|
||||||
|
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
|
||||||
|
else:
|
||||||
|
x_paths = [op(x) for op in self.paths]
|
||||||
|
|
||||||
|
x = torch.stack(x_paths, dim=1)
|
||||||
|
x_attn = self.attn(x)
|
||||||
|
x = x * x_attn
|
||||||
|
|
||||||
|
if self.split_input:
|
||||||
|
B, N, C, H, W = x.shape
|
||||||
|
x = x.reshape(B, N * C, H, W)
|
||||||
|
else:
|
||||||
|
x = torch.sum(x, dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
# helper method
|
# helper method
|
||||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||||
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
||||||
@ -256,5 +346,3 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
|||||||
else:
|
else:
|
||||||
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,14 +54,15 @@ class Bottle2neck(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||||
cardinality=1, base_width=26, scale=4, use_se=False,
|
cardinality=1, base_width=26, scale=4, use_se=False,
|
||||||
act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
|
act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_):
|
||||||
super(Bottle2neck, self).__init__()
|
super(Bottle2neck, self).__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.is_first = stride > 1 or downsample is not None
|
self.is_first = stride > 1 or downsample is not None
|
||||||
self.num_scales = max(1, scale - 1)
|
self.num_scales = max(1, scale - 1)
|
||||||
width = int(math.floor(planes * (base_width / 64.0))) * cardinality
|
width = int(math.floor(planes * (base_width / 64.0))) * cardinality
|
||||||
outplanes = planes * self.expansion
|
|
||||||
self.width = width
|
self.width = width
|
||||||
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
|
self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
|
||||||
self.bn1 = norm_layer(width * scale)
|
self.bn1 = norm_layer(width * scale)
|
||||||
@ -70,8 +71,8 @@ class Bottle2neck(nn.Module):
|
|||||||
bns = []
|
bns = []
|
||||||
for i in range(self.num_scales):
|
for i in range(self.num_scales):
|
||||||
convs.append(nn.Conv2d(
|
convs.append(nn.Conv2d(
|
||||||
width, width, kernel_size=3, stride=stride, padding=dilation,
|
width, width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||||
dilation=dilation, groups=cardinality, bias=False))
|
dilation=first_dilation, groups=cardinality, bias=False))
|
||||||
bns.append(norm_layer(width))
|
bns.append(norm_layer(width))
|
||||||
self.convs = nn.ModuleList(convs)
|
self.convs = nn.ModuleList(convs)
|
||||||
self.bns = nn.ModuleList(bns)
|
self.bns = nn.ModuleList(bns)
|
||||||
|
@ -131,24 +131,23 @@ class BasicBlock(nn.Module):
|
|||||||
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
|
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
|
||||||
cardinality=1, base_width=64, use_se=False,
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
||||||
super(BasicBlock, self).__init__()
|
super(BasicBlock, self).__init__()
|
||||||
|
|
||||||
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||||
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
||||||
first_planes = planes // reduce_first
|
first_planes = planes // reduce_first
|
||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(
|
self.conv1 = nn.Conv2d(
|
||||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
|
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
||||||
dilation=dilation, bias=False)
|
dilation=first_dilation, bias=False)
|
||||||
self.bn1 = norm_layer(first_planes)
|
self.bn1 = norm_layer(first_planes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(
|
||||||
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
|
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
||||||
dilation=previous_dilation, bias=False)
|
|
||||||
self.bn2 = norm_layer(outplanes)
|
self.bn2 = norm_layer(outplanes)
|
||||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||||
self.act2 = act_layer(inplace=True)
|
self.act2 = act_layer(inplace=True)
|
||||||
@ -181,21 +180,21 @@ class Bottleneck(nn.Module):
|
|||||||
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
|
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
|
||||||
expansion = 4
|
expansion = 4
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
|
||||||
cardinality=1, base_width=64, use_se=False,
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
||||||
super(Bottleneck, self).__init__()
|
super(Bottleneck, self).__init__()
|
||||||
|
|
||||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||||
first_planes = width // reduce_first
|
first_planes = width // reduce_first
|
||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
||||||
self.bn1 = norm_layer(first_planes)
|
self.bn1 = norm_layer(first_planes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(
|
||||||
first_planes, width, kernel_size=3, stride=stride,
|
first_planes, width, kernel_size=3, stride=stride,
|
||||||
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
|
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||||
self.bn2 = norm_layer(width)
|
self.bn2 = norm_layer(width)
|
||||||
self.act2 = act_layer(inplace=True)
|
self.act2 = act_layer(inplace=True)
|
||||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||||
@ -396,13 +395,11 @@ class ResNet(nn.Module):
|
|||||||
first_dilation = 1 if dilation in (1, 2) else 2
|
first_dilation = 1 if dilation in (1, 2) else 2
|
||||||
bkwargs = dict(
|
bkwargs = dict(
|
||||||
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
||||||
use_se=use_se, **kwargs)
|
dilation=dilation, use_se=use_se, **kwargs)
|
||||||
layers = [block(
|
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)]
|
||||||
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
|
|
||||||
self.inplanes = planes * block.expansion
|
self.inplanes = planes * block.expansion
|
||||||
for i in range(1, blocks):
|
for i in range(1, blocks):
|
||||||
layers.append(block(
|
layers.append(block(self.inplanes, planes, **bkwargs))
|
||||||
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs))
|
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
@ -430,8 +427,8 @@ class ResNet(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.forward_features(x)
|
x = self.forward_features(x)
|
||||||
x = self.global_pool(x).flatten(1)
|
x = self.global_pool(x).flatten(1)
|
||||||
if self.drop_rate > 0.:
|
if self.drop_rate:
|
||||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import math
|
import math
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from timm.models.registry import register_model
|
from timm.models.registry import register_model
|
||||||
from timm.models.helpers import load_pretrained
|
from timm.models.helpers import load_pretrained
|
||||||
from timm.models.resnet import ResNet, get_padding, SEModule
|
from timm.models.conv2d_layers import SelectiveKernelConv
|
||||||
|
from timm.models.resnet import ResNet, SEModule
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
|
||||||
@ -27,113 +26,12 @@ default_cfgs = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SelectiveKernelAttn(nn.Module):
|
|
||||||
def __init__(self, channels, num_paths=2, attn_channels=32,
|
|
||||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
||||||
super(SelectiveKernelAttn, self).__init__()
|
|
||||||
self.num_paths = num_paths
|
|
||||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
|
|
||||||
self.bn = norm_layer(attn_channels)
|
|
||||||
self.act = act_layer(inplace=True)
|
|
||||||
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
assert x.shape[1] == self.num_paths
|
|
||||||
x = torch.sum(x, dim=1)
|
|
||||||
#print('attn sum', x.shape)
|
|
||||||
x = self.pool(x)
|
|
||||||
#print('attn pool', x.shape)
|
|
||||||
x = self.fc_reduce(x)
|
|
||||||
x = self.bn(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.fc_select(x)
|
|
||||||
#print('attn sel', x.shape)
|
|
||||||
B, C, H, W = x.shape
|
|
||||||
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
|
|
||||||
#print('attn spl', x.shape)
|
|
||||||
x = torch.softmax(x, dim=1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def _kernel_valid(k):
|
|
||||||
if isinstance(k, (list, tuple)):
|
|
||||||
for ki in k:
|
|
||||||
return _kernel_valid(ki)
|
|
||||||
assert k >= 3 and k % 2
|
|
||||||
|
|
||||||
|
|
||||||
class SelectiveKernelConv(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1,
|
|
||||||
attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True,
|
|
||||||
split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
||||||
super(SelectiveKernelConv, self).__init__()
|
|
||||||
_kernel_valid(kernel_size)
|
|
||||||
if not isinstance(kernel_size, list):
|
|
||||||
kernel_size = [kernel_size] * 2
|
|
||||||
if keep_3x3:
|
|
||||||
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
|
|
||||||
kernel_size = [3] * len(kernel_size)
|
|
||||||
else:
|
|
||||||
dilation = [dilation] * len(kernel_size)
|
|
||||||
num_paths = len(kernel_size)
|
|
||||||
self.num_paths = num_paths
|
|
||||||
self.split_input = split_input
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
if split_input:
|
|
||||||
assert in_channels % num_paths == 0 and out_channels % num_paths == 0
|
|
||||||
in_channels = in_channels // num_paths
|
|
||||||
out_channels = out_channels // num_paths
|
|
||||||
groups = min(out_channels, groups)
|
|
||||||
|
|
||||||
self.paths = nn.ModuleList()
|
|
||||||
for k, d in zip(kernel_size, dilation):
|
|
||||||
p = get_padding(k, stride, d)
|
|
||||||
self.paths.append(nn.Sequential(OrderedDict([
|
|
||||||
('conv', nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
|
|
||||||
('bn', norm_layer(out_channels)),
|
|
||||||
('act', act_layer(inplace=True))
|
|
||||||
])))
|
|
||||||
|
|
||||||
if use_attn:
|
|
||||||
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
|
|
||||||
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
|
|
||||||
else:
|
|
||||||
self.attn = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.split_input:
|
|
||||||
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
|
|
||||||
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
|
|
||||||
else:
|
|
||||||
x_paths = [op(x) for op in self.paths]
|
|
||||||
|
|
||||||
if self.attn is not None:
|
|
||||||
x = torch.stack(x_paths, dim=1)
|
|
||||||
# print('paths', x_paths.shape)
|
|
||||||
x_attn = self.attn(x)
|
|
||||||
#print('attn', x_attn.shape)
|
|
||||||
x = x * x_attn
|
|
||||||
#print('amul', x.shape)
|
|
||||||
|
|
||||||
if self.split_input:
|
|
||||||
B, N, C, H, W = x.shape
|
|
||||||
x = x.reshape(B, N * C, H, W)
|
|
||||||
else:
|
|
||||||
x = torch.sum(x, dim=1)
|
|
||||||
#print('aout', x.shape)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SelectiveKernelBasic(nn.Module):
|
class SelectiveKernelBasic(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||||
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
||||||
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
super(SelectiveKernelBasic, self).__init__()
|
super(SelectiveKernelBasic, self).__init__()
|
||||||
|
|
||||||
sk_kwargs = sk_kwargs or {}
|
sk_kwargs = sk_kwargs or {}
|
||||||
@ -141,24 +39,25 @@ class SelectiveKernelBasic(nn.Module):
|
|||||||
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
assert base_width == 64, 'BasicBlock doest not support changing base width'
|
||||||
first_planes = planes // reduce_first
|
first_planes = planes // reduce_first
|
||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
_selective_first = True # FIXME temporary, for experiments
|
_selective_first = True # FIXME temporary, for experiments
|
||||||
if _selective_first:
|
if _selective_first:
|
||||||
self.conv1 = SelectiveKernelConv(
|
self.conv1 = SelectiveKernelConv(
|
||||||
inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs)
|
inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs)
|
||||||
else:
|
else:
|
||||||
self.conv1 = nn.Conv2d(
|
self.conv1 = nn.Conv2d(
|
||||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
|
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
||||||
dilation=dilation, bias=False)
|
dilation=first_dilation, bias=False)
|
||||||
self.bn1 = norm_layer(first_planes)
|
self.bn1 = norm_layer(first_planes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
if _selective_first:
|
if _selective_first:
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(
|
||||||
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
|
first_planes, outplanes, kernel_size=3, padding=dilation,
|
||||||
dilation=previous_dilation, bias=False)
|
dilation=dilation, bias=False)
|
||||||
else:
|
else:
|
||||||
self.conv2 = SelectiveKernelConv(
|
self.conv2 = SelectiveKernelConv(
|
||||||
first_planes, outplanes, dilation=previous_dilation, **sk_kwargs)
|
first_planes, outplanes, dilation=dilation, **sk_kwargs)
|
||||||
self.bn2 = norm_layer(outplanes)
|
self.bn2 = norm_layer(outplanes)
|
||||||
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
self.se = SEModule(outplanes, planes // 4) if use_se else None
|
||||||
self.act2 = act_layer(inplace=True)
|
self.act2 = act_layer(inplace=True)
|
||||||
@ -192,19 +91,20 @@ class SelectiveKernelBottleneck(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||||
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
|
||||||
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||||
super(SelectiveKernelBottleneck, self).__init__()
|
super(SelectiveKernelBottleneck, self).__init__()
|
||||||
|
|
||||||
sk_kwargs = sk_kwargs or {}
|
sk_kwargs = sk_kwargs or {}
|
||||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||||
first_planes = width // reduce_first
|
first_planes = width // reduce_first
|
||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
||||||
self.bn1 = norm_layer(first_planes)
|
self.bn1 = norm_layer(first_planes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
self.conv2 = SelectiveKernelConv(
|
self.conv2 = SelectiveKernelConv(
|
||||||
first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs)
|
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs)
|
||||||
self.bn2 = norm_layer(width)
|
self.bn2 = norm_layer(width)
|
||||||
self.act2 = act_layer(inplace=True)
|
self.act2 = act_layer(inplace=True)
|
||||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user