Some cleanup and fixes for initial BlurPool impl. Still some testing and tweaks to go...
parent
acd1b6cccd
commit
6cdeca24a3
|
@ -1,14 +1,17 @@
|
|||
'''
|
||||
"""
|
||||
BlurPool layer inspired by
|
||||
Kornia's Max_BlurPool2d
|
||||
and
|
||||
Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
||||
- Kornia's Max_BlurPool2d
|
||||
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
||||
|
||||
Hacked together by Chris Ha and Ross Wightman
|
||||
"""
|
||||
|
||||
'''
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .padding import get_padding
|
||||
|
||||
|
||||
class BlurPool2d(nn.Module):
    """Anti-aliasing blur-pool layer.

    Applies a depthwise binomial ("blur") filter with reflection padding before
    strided downsampling, inspired by Kornia's ``Max_BlurPool2d`` and
    `Making Convolutional Networks Shift-Invariant Again` :cite:`zhang2019shiftinvar`.

    Args:
        channels (int): number of input (and output) channels; the filter is
            applied per-channel via a grouped convolution.
        blur_filter_size (int): side length of the square binomial filter,
            must be > 1 (3 -> [1, 2, 1]/4, 5 -> [1, 4, 6, 4, 1]/16, ...).
        stride (int): downsampling stride of the blur convolution.

    Shape:
        - Input: ``(B, C, H, W)``
        - Output: ``(B, C, H', W')`` downsampled by ``stride``
    """

    def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
        super(BlurPool2d, self).__init__()
        assert blur_filter_size > 1
        self.channels = channels
        self.blur_filter_size = blur_filter_size
        self.stride = stride

        pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
        self.padding = nn.ReflectionPad2d(pad_size)

        # 1-D binomial coefficients via (0.5 + 0.5x)^(k-1), e.g. k=3 -> [.25, .5, .25];
        # the outer product yields the separable 2-D blur kernel.
        blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
        blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
        # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
        # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
        self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
        if not torch.is_tensor(input_tensor):
            # fixed duplicated word in the original message ("Input input type")
            raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
        if not len(input_tensor.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
        # apply blur_filter per channel (grouped conv), cast to the input dtype
        return F.conv2d(
            self.padding(input_tensor),
            self.blur_filter.type(input_tensor.dtype),
            stride=self.stride,
            groups=input_tensor.shape[1])
|
|
|
@ -127,21 +127,14 @@ class BasicBlock(nn.Module):
|
|||
first_planes = planes // reduce_first
|
||||
outplanes = planes * self.expansion
|
||||
first_dilation = first_dilation or dilation
|
||||
self.blur = blur
|
||||
|
||||
if blur and stride==2:
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation,
|
||||
dilation=first_dilation, bias=False)
|
||||
self.blurpool=BlurPool2d(channels=first_planes)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
|
||||
dilation=first_dilation, bias=False)
|
||||
self.blurpool = None
|
||||
|
||||
self.bn1 = norm_layer(first_planes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
||||
self.bn2 = norm_layer(outplanes)
|
||||
|
@ -165,11 +158,9 @@ class BasicBlock(nn.Module):
|
|||
x = self.bn1(x)
|
||||
if self.drop_block is not None:
|
||||
x = self.drop_block(x)
|
||||
x = self.act1(x)
|
||||
if self.blurpool is not None:
|
||||
x = self.act1(x)
|
||||
x = self.blurpool(x)
|
||||
else:
|
||||
x = self.act1(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
|
@ -209,19 +200,13 @@ class Bottleneck(nn.Module):
|
|||
self.bn1 = norm_layer(first_planes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
if blur and stride==2:
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=1,
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.blurpool = BlurPool2d(channels=width)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=stride,
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.blurpool = None
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, width, kernel_size=3, stride=1 if blur else stride,
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
|
||||
|
||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||
self.bn3 = norm_layer(outplanes)
|
||||
|
||||
|
@ -251,6 +236,8 @@ class Bottleneck(nn.Module):
|
|||
if self.drop_block is not None:
|
||||
x = self.drop_block(x)
|
||||
x = self.act2(x)
|
||||
if self.blurpool is not None:
|
||||
x = self.blurpool(x)
|
||||
|
||||
x = self.conv3(x)
|
||||
x = self.bn3(x)
|
||||
|
@ -412,11 +399,12 @@ class ResNet(nn.Module):
|
|||
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
# Stem Blur
|
||||
# Stem Pooling
|
||||
if 'max' in blur :
|
||||
self.maxpool = nn.Sequential(*[
|
||||
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
||||
BlurPool2d(channels=self.inplanes)])
|
||||
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
||||
BlurPool2d(channels=self.inplanes, stride=2)
|
||||
])
|
||||
else :
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
|
@ -470,8 +458,8 @@ class ResNet(nn.Module):
|
|||
|
||||
block_kwargs = dict(
|
||||
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
||||
dilation=dilation, **kwargs)
|
||||
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)]
|
||||
dilation=dilation, blur=self.blur, **kwargs)
|
||||
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
|
||||
self.inplanes = planes * block.expansion
|
||||
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
|
||||
|
||||
|
@ -1075,7 +1063,7 @@ def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-50 model with assembled-cnn style blur pooling.

    Args:
        pretrained (bool): NOTE(review): accepted but currently unused — no
            pretrained weights are loaded even when True; confirm intended.
        num_classes (int): number of classifier output classes.
        in_chans (int): number of input image channels.

    Returns:
        ResNet: model built with Bottleneck blocks [3, 4, 6, 3] and
        ``blur='max_strided'``, with ``default_cfg`` attached.
    """
    default_cfg = default_cfgs['resnetblur50']
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
    model.default_cfg = default_cfg
    return model
|
Loading…
Reference in New Issue