Implement Functional Blur on resnet.py
1. add ResNet argument blur='' 2. implement blur for maxpool and strided convs in downsampling blockspull/101/head
parent
ce3d82b58b
commit
acd1b6cccd
|
@ -15,3 +15,4 @@ from .adaptive_avgmax_pool import \
|
|||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||
from .blurpool import BlurPool2d
|
||||
|
|
|
@ -17,7 +17,7 @@ class BlurPool2d(nn.Module):
|
|||
Corresponds to the Downsample class, which does blurring and subsampling
|
||||
Args:
|
||||
channels = Number of input channels
|
||||
blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5.
|
||||
blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
|
||||
stride (int): downsampling filter stride
|
||||
Shape:
|
||||
Returns:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn.functional as F
|
|||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
|
||||
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
|
@ -104,6 +104,8 @@ default_cfgs = {
|
|||
interpolation='bicubic'),
|
||||
'ecaresnet18': _cfg(),
|
||||
'ecaresnet50': _cfg(),
|
||||
'resnetblur18': _cfg(),
|
||||
'resnetblur50': _cfg()
|
||||
}
|
||||
|
||||
|
||||
|
@ -117,7 +119,7 @@ class BasicBlock(nn.Module):
|
|||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||
attn_layer=None, drop_block=None, drop_path=None):
|
||||
attn_layer=None, drop_block=None, drop_path=None, blur=False):
|
||||
super(BasicBlock, self).__init__()
|
||||
|
||||
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||
|
@ -125,10 +127,19 @@ class BasicBlock(nn.Module):
|
|||
first_planes = planes // reduce_first
|
||||
outplanes = planes * self.expansion
|
||||
first_dilation = first_dilation or dilation
|
||||
self.blur = blur
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
if blur and stride==2:
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation,
|
||||
dilation=first_dilation, bias=False)
|
||||
self.blurpool=BlurPool2d(channels=first_planes)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
dilation=first_dilation, bias=False)
|
||||
self.blurpool = None
|
||||
|
||||
self.bn1 = norm_layer(first_planes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
|
@ -154,7 +165,11 @@ class BasicBlock(nn.Module):
|
|||
x = self.bn1(x)
|
||||
if self.drop_block is not None:
|
||||
x = self.drop_block(x)
|
||||
x = self.act1(x)
|
||||
if self.blurpool is not None:
|
||||
x = self.act1(x)
|
||||
x = self.blurpool(x)
|
||||
else:
|
||||
x = self.act1(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
|
@ -181,20 +196,30 @@ class Bottleneck(nn.Module):
|
|||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||
attn_layer=None, drop_block=None, drop_path=None):
|
||||
attn_layer=None, drop_block=None, drop_path=None, blur=False):
|
||||
super(Bottleneck, self).__init__()
|
||||
|
||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||
first_planes = width // reduce_first
|
||||
outplanes = planes * self.expansion
|
||||
first_dilation = first_dilation or dilation
|
||||
self.blur = blur
|
||||
|
||||
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
||||
self.bn1 = norm_layer(first_planes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=stride,
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
|
||||
if blur and stride==2:
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=1,
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.blurpool = BlurPool2d(channels=width)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=stride,
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.blurpool = None
|
||||
|
||||
self.bn2 = norm_layer(width)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||
|
@ -345,12 +370,19 @@ class ResNet(nn.Module):
|
|||
Dropout probability before classifier, for training
|
||||
global_pool : str, default 'avg'
|
||||
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
|
||||
blur : str, default ''
|
||||
Location of Blurring:
|
||||
* '', default - Not applied
|
||||
* 'max' - only stem layer MaxPool will be blurred
|
||||
* 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
|
||||
* 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
|
||||
|
||||
"""
|
||||
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
||||
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
||||
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
|
||||
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
|
||||
drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None):
|
||||
block_args = block_args or dict()
|
||||
self.num_classes = num_classes
|
||||
deep_stem = 'deep' in stem_type
|
||||
|
@ -359,6 +391,7 @@ class ResNet(nn.Module):
|
|||
self.base_width = base_width
|
||||
self.drop_rate = drop_rate
|
||||
self.expansion = block.expansion
|
||||
self.blur = 'strided' in blur
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
# Stem
|
||||
|
@ -379,7 +412,13 @@ class ResNet(nn.Module):
|
|||
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
# Stem Blur
|
||||
if 'max' in blur :
|
||||
self.maxpool = nn.Sequential(*[
|
||||
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
||||
BlurPool2d(channels=self.inplanes)])
|
||||
else :
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
# Feature Blocks
|
||||
dp = DropPath(drop_path_rate) if drop_path_rate else None
|
||||
|
@ -432,7 +471,7 @@ class ResNet(nn.Module):
|
|||
block_kwargs = dict(
|
||||
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
||||
dilation=dilation, **kwargs)
|
||||
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
|
||||
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)]
|
||||
self.inplanes = planes * block.expansion
|
||||
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
|
||||
|
||||
|
@ -1022,3 +1061,21 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-18 model. With original style blur
|
||||
"""
|
||||
default_cfg = default_cfgs['resnetblur18']
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a ResNet-50 model. With assembled-cnn style blur
|
||||
"""
|
||||
default_cfg = default_cfgs['resnetblur18']
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
Loading…
Reference in New Issue