Update comments for Selective Kernel and DropBlock/Path impl, add skresnet34 weights
parent 569419b38d
commit f1d5f8a6c4
@@ -12,6 +12,6 @@ from .eca import EcaModule, CecaModule
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .drop import DropBlock2d, DropPath
+from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
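The export change above makes the functional forms usable directly, alongside the DropBlock2d/DropPath module wrappers. A minimal usage sketch, assuming the timm.models.layers package path this __init__.py belongs to:

    import torch
    from timm.models.layers import drop_block_2d, drop_path

    x = torch.randn(8, 64, 56, 56)
    # functional forms take an explicit training flag instead of reading nn.Module state
    y = drop_block_2d(x, drop_prob=0.1, training=True)
    z = drop_path(x, drop_prob=0.1, training=True)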
@@ -2,6 +2,16 @@
 
 PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
 
+Papers:
+DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+Code:
+DropBlock impl inspired by two Tensorflow impl that I liked:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
 Hacked together by Ross Wightman
 """
 import torch
@@ -11,9 +21,15 @@ import numpy as np
 import math
 
 
-def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False):
+def drop_block_2d(x, drop_prob=0.1, training=False, block_size=7, gamma_scale=1.0, drop_with_noise=False):
+    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
+    runs with success, but needs further validation and possibly optimization for lower runtime impact.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
     _, _, height, width = x.shape
     total_size = width * height
     clipped_block_size = min(block_size, min(width, height))
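The hunk cuts off where the function begins deriving the block seed probability from total_size and clipped_block_size. For reference, that probability follows Eq. 1 of the DropBlock paper, with gamma_scale as this impl's extra scaling knob; a standalone sketch (the helper name is illustrative, not from the repo):

    def drop_block_gamma(drop_prob, width, height, block_size, gamma_scale=1.0):
        # scale the per-position seed probability so that, in expectation,
        # roughly drop_prob of the activations fall inside dropped blocks
        total_size = width * height
        clipped_block_size = min(block_size, min(width, height))
        valid_seed_region = (width - block_size + 1) * (height - block_size + 1)
        return gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / valid_seed_region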
@@ -60,14 +76,21 @@ class DropBlock2d(nn.Module):
         self.with_noise = with_noise
 
     def forward(self, x):
-        if not self.training or not self.drop_prob:
-            return x
-        return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise)
+        return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
 
 
-def drop_path(x, drop_prob=0.):
-    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
     """
+    if drop_prob == 0. or not training:
+        return x
     keep_prob = 1 - drop_prob
     random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
     random_tensor.floor_()  # binarize
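The hunk ends at the binarized per-sample mask; for context, here is a self-contained sketch of the complete mechanism, including the rescaling step that follows the cutoff (this mirrors the standard stochastic depth formulation, hedged rather than quoted from the file):

    import torch

    def drop_path_sketch(x, drop_prob=0.1, training=True):
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        # one Bernoulli draw per sample, broadcast across C, H, W
        mask = keep_prob + torch.rand((x.size(0), 1, 1, 1), dtype=x.dtype, device=x.device)
        mask.floor_()  # binarize: 1 with probability keep_prob, else 0
        # rescale survivors so the expected output matches eval behavior
        return x.div(keep_prob) * mask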
@@ -76,13 +99,11 @@ def drop_path(x, drop_prob=0.):
 
 
 class DropPath(nn.ModuleDict):
-    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
     def __init__(self, drop_prob=None):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
 
     def forward(self, x):
-        if not self.training or not self.drop_prob:
-            return x
-        return drop_path(x, self.drop_prob)
+        return drop_path(x, self.drop_prob, self.training)
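"Main path of residual blocks" in the updated docstrings refers to placement like the following; a hypothetical block, not code from this repo:

    import torch.nn as nn

    class ResidualUnit(nn.Module):
        def __init__(self, dim, drop_prob=0.1):
            super(ResidualUnit, self).__init__()
            self.conv = nn.Conv2d(dim, dim, 3, padding=1)
            self.drop_path = DropPath(drop_prob)  # DropPath from the module above

        def forward(self, x):
            # dropping the main path collapses the block to identity for that sample
            return x + self.drop_path(self.conv(x))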
@@ -21,6 +21,11 @@ def _kernel_valid(k):
 class SelectiveKernelAttn(nn.Module):
     def __init__(self, channels, num_paths=2, attn_channels=32,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        """ Selective Kernel Attention Module
+
+        Selective Kernel attention mechanism factored out into its own module.
+
+        """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
         self.pool = nn.AdaptiveAvgPool2d(1)
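The attention module produces per-path, per-channel logits from globally pooled features; a shape-level sketch of the selection step it feeds (tensor layout assumed, paraphrasing the mechanism rather than quoting the module):

    import torch

    B, num_paths, C, H, W = 2, 2, 64, 28, 28
    branch_out = torch.randn(B, num_paths, C, H, W)    # stacked conv branch outputs
    attn_logits = torch.randn(B, num_paths, C, 1, 1)   # from pooled features via the attn convs
    attn = torch.softmax(attn_logits, dim=1)           # paths compete per channel
    mixed = (branch_out * attn).sum(dim=1)             # (B, C, H, W) selected features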
@@ -48,8 +53,33 @@ class SelectiveKernelConv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
                  attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
                  drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        """ Selective Kernel Convolution Module
+
+        As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
+
+        Largest change is the input split, which divides the input channels across each convolution path, this can
+        be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
+        the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
+        a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
+
+        Args:
+            in_channels (int): module input (feature) channel count
+            out_channels (int): module output (feature) channel count
+            kernel_size (int, list): kernel size for each convolution branch
+            stride (int): stride for convolutions
+            dilation (int): dilation for module as a whole, impacts dilation of each branch
+            groups (int): number of groups for each branch
+            attn_reduction (int, float): reduction factor for attention features
+            min_attn_channels (int): minimum attention feature channels
+            keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
+            split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
+                can be viewed as grouping by path, output expands to module out_channels count
+            drop_block (nn.Module): drop block module
+            act_layer (nn.Module): activation layer to use
+            norm_layer (nn.Module): batchnorm/norm layer to use
+        """
         super(SelectiveKernelConv, self).__init__()
-        kernel_size = kernel_size or [3, 5]
+        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
         _kernel_valid(kernel_size)
         if not isinstance(kernel_size, list):
             kernel_size = [kernel_size] * 2
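A usage sketch for the documented signature, checking the docstring's split_input claim empirically (illustrative only; exact counts depend on the branch configs):

    conv_full = SelectiveKernelConv(64, 64, split_input=False)
    conv_split = SelectiveKernelConv(64, 64, split_input=True)   # 32 input channels per path
    n_full = sum(p.numel() for p in conv_full.parameters())
    n_split = sum(p.numel() for p in conv_split.parameters())
    print(n_full, n_split)  # the split variant should come out smaller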
@@ -382,7 +382,7 @@ class ResNet(nn.Module):
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
-        dp = DropPath(drop_path_rate) if drop_block_rate else None
+        dp = DropPath(drop_path_rate) if drop_path_rate else None
         db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
         db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
         channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
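The one-line fix matters because the old guard tested drop_block_rate, so DropPath was silently skipped unless DropBlock was also enabled. With the fix, stochastic depth can be used on its own; a hedged sketch, assuming timm's create_model forwards the kwarg to this constructor:

    import timm

    # drop_path_rate alone now instantiates DropPath for the feature blocks
    model = timm.create_model('resnet50', drop_path_rate=0.05)
    model.train()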
@@ -2,6 +2,10 @@
 
 Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
 
+This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
+and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
+to the original paper with some modifications of my own to better balance param count vs accuracy.
+
 Hacked together by Ross Wightman
 """
 import math
@@ -29,7 +33,8 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     'skresnet18': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
-    'skresnet34': _cfg(url=''),
+    'skresnet34': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
     'skresnet50': _cfg(),
     'skresnet50d': _cfg(),
     'skresnext50_32x4d': _cfg(
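With the URL filled in, pretrained=True should now resolve for skresnet34; a quick smoke test under standard timm usage (assumed entry point, not part of this diff):

    import timm
    import torch

    model = timm.create_model('skresnet34', pretrained=True)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expect torch.Size([1, 1000])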