mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'blur' of https://github.com/VRandme/pytorch-image-models into VRandme-blur
This commit is contained in:
commit
9590f301a9
@ -18,3 +18,4 @@ from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
|||||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||||
from .anti_aliasing import AntiAliasDownsampleLayer
|
from .anti_aliasing import AntiAliasDownsampleLayer
|
||||||
from .space_to_depth import SpaceToDepthModule
|
from .space_to_depth import SpaceToDepthModule
|
||||||
|
from .blurpool import BlurPool2d
|
||||||
|
55
timm/models/layers/blurpool.py
Normal file
55
timm/models/layers/blurpool.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
BlurPool layer inspired by
|
||||||
|
- Kornia's Max_BlurPool2d
|
||||||
|
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
||||||
|
|
||||||
|
Hacked together by Chris Ha and Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from .padding import get_padding
|
||||||
|
|
||||||
|
|
||||||
|
class BlurPool2d(nn.Module):
|
||||||
|
r"""Creates a module that computes blurs and downsample a given feature map.
|
||||||
|
See :cite:`zhang2019shiftinvar` for more details.
|
||||||
|
Corresponds to the Downsample class, which does blurring and subsampling
|
||||||
|
Args:
|
||||||
|
channels = Number of input channels
|
||||||
|
blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
|
||||||
|
stride (int): downsampling filter stride
|
||||||
|
Shape:
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: the transformed tensor.
|
||||||
|
Examples:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels, blur_filter_size=3, stride=2) -> None:
|
||||||
|
super(BlurPool2d, self).__init__()
|
||||||
|
assert blur_filter_size > 1
|
||||||
|
self.channels = channels
|
||||||
|
self.blur_filter_size = blur_filter_size
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
pad_size = [get_padding(blur_filter_size, stride, dilation=1)] * 4
|
||||||
|
self.padding = nn.ReflectionPad2d(pad_size)
|
||||||
|
|
||||||
|
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
|
||||||
|
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
|
||||||
|
self.blur_filter = blur_filter[None, None, :, :]
|
||||||
|
|
||||||
|
def _apply(self, fn):
|
||||||
|
# override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
|
||||||
|
# this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
|
||||||
|
super(BlurPool2d, self)._apply(fn)
|
||||||
|
self.blur_filter = fn(self.blur_filter)
|
||||||
|
|
||||||
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||||
|
C = input_tensor.shape[1]
|
||||||
|
return F.conv2d(
|
||||||
|
self.padding(input_tensor),
|
||||||
|
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)
|
@ -12,7 +12,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .helpers import load_pretrained, adapt_model_from_file
|
from .helpers import load_pretrained, adapt_model_from_file
|
||||||
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
|
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
|
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
|
||||||
@ -118,6 +118,8 @@ default_cfgs = {
|
|||||||
'ecaresnet101d_pruned': _cfg(
|
'ecaresnet101d_pruned': _cfg(
|
||||||
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
|
'resnetblur18': _cfg(),
|
||||||
|
'resnetblur50': _cfg()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +133,7 @@ class BasicBlock(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
||||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||||
attn_layer=None, drop_block=None, drop_path=None):
|
attn_layer=None, drop_block=None, drop_path=None, blur=False):
|
||||||
super(BasicBlock, self).__init__()
|
super(BasicBlock, self).__init__()
|
||||||
|
|
||||||
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
||||||
@ -141,10 +143,12 @@ class BasicBlock(nn.Module):
|
|||||||
first_dilation = first_dilation or dilation
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(
|
self.conv1 = nn.Conv2d(
|
||||||
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
inplanes, first_planes, kernel_size=3, stride=1 if blur else stride, padding=first_dilation,
|
||||||
dilation=first_dilation, bias=False)
|
dilation=first_dilation, bias=False)
|
||||||
self.bn1 = norm_layer(first_planes)
|
self.bn1 = norm_layer(first_planes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
|
self.blurpool = BlurPool2d(channels=first_planes) if stride == 2 and blur else None
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(
|
||||||
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
||||||
self.bn2 = norm_layer(outplanes)
|
self.bn2 = norm_layer(outplanes)
|
||||||
@ -169,6 +173,8 @@ class BasicBlock(nn.Module):
|
|||||||
if self.drop_block is not None:
|
if self.drop_block is not None:
|
||||||
x = self.drop_block(x)
|
x = self.drop_block(x)
|
||||||
x = self.act1(x)
|
x = self.act1(x)
|
||||||
|
if self.blurpool is not None:
|
||||||
|
x = self.blurpool(x)
|
||||||
|
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.bn2(x)
|
x = self.bn2(x)
|
||||||
@ -195,22 +201,26 @@ class Bottleneck(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
|
||||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||||
attn_layer=None, drop_block=None, drop_path=None):
|
attn_layer=None, drop_block=None, drop_path=None, blur=False):
|
||||||
super(Bottleneck, self).__init__()
|
super(Bottleneck, self).__init__()
|
||||||
|
|
||||||
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
||||||
first_planes = width // reduce_first
|
first_planes = width // reduce_first
|
||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
first_dilation = first_dilation or dilation
|
first_dilation = first_dilation or dilation
|
||||||
|
self.blur = blur
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
||||||
self.bn1 = norm_layer(first_planes)
|
self.bn1 = norm_layer(first_planes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(
|
||||||
first_planes, width, kernel_size=3, stride=stride,
|
first_planes, width, kernel_size=3, stride=1 if blur else stride,
|
||||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||||
self.bn2 = norm_layer(width)
|
self.bn2 = norm_layer(width)
|
||||||
self.act2 = act_layer(inplace=True)
|
self.act2 = act_layer(inplace=True)
|
||||||
|
self.blurpool = BlurPool2d(channels=width) if stride == 2 and blur else None
|
||||||
|
|
||||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||||
self.bn3 = norm_layer(outplanes)
|
self.bn3 = norm_layer(outplanes)
|
||||||
|
|
||||||
@ -240,6 +250,8 @@ class Bottleneck(nn.Module):
|
|||||||
if self.drop_block is not None:
|
if self.drop_block is not None:
|
||||||
x = self.drop_block(x)
|
x = self.drop_block(x)
|
||||||
x = self.act2(x)
|
x = self.act2(x)
|
||||||
|
if self.blurpool is not None:
|
||||||
|
x = self.blurpool(x)
|
||||||
|
|
||||||
x = self.conv3(x)
|
x = self.conv3(x)
|
||||||
x = self.bn3(x)
|
x = self.bn3(x)
|
||||||
@ -359,12 +371,19 @@ class ResNet(nn.Module):
|
|||||||
Dropout probability before classifier, for training
|
Dropout probability before classifier, for training
|
||||||
global_pool : str, default 'avg'
|
global_pool : str, default 'avg'
|
||||||
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
|
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
|
||||||
|
blur : str, default ''
|
||||||
|
Location of Blurring:
|
||||||
|
* '', default - Not applied
|
||||||
|
* 'max' - only stem layer MaxPool will be blurred
|
||||||
|
* 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style)
|
||||||
|
* 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
||||||
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
cardinality=1, base_width=64, stem_width=64, stem_type='',
|
||||||
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
|
||||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
|
||||||
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
|
drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None):
|
||||||
block_args = block_args or dict()
|
block_args = block_args or dict()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
deep_stem = 'deep' in stem_type
|
deep_stem = 'deep' in stem_type
|
||||||
@ -373,6 +392,7 @@ class ResNet(nn.Module):
|
|||||||
self.base_width = base_width
|
self.base_width = base_width
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
self.expansion = block.expansion
|
self.expansion = block.expansion
|
||||||
|
self.blur = 'strided' in blur
|
||||||
super(ResNet, self).__init__()
|
super(ResNet, self).__init__()
|
||||||
|
|
||||||
# Stem
|
# Stem
|
||||||
@ -393,6 +413,13 @@ class ResNet(nn.Module):
|
|||||||
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
||||||
self.bn1 = norm_layer(self.inplanes)
|
self.bn1 = norm_layer(self.inplanes)
|
||||||
self.act1 = act_layer(inplace=True)
|
self.act1 = act_layer(inplace=True)
|
||||||
|
# Stem Pooling
|
||||||
|
if 'max' in blur :
|
||||||
|
self.maxpool = nn.Sequential(*[
|
||||||
|
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
||||||
|
BlurPool2d(channels=self.inplanes, stride=2)
|
||||||
|
])
|
||||||
|
else :
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
# Feature Blocks
|
# Feature Blocks
|
||||||
@ -445,7 +472,7 @@ class ResNet(nn.Module):
|
|||||||
|
|
||||||
block_kwargs = dict(
|
block_kwargs = dict(
|
||||||
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
|
||||||
dilation=dilation, **kwargs)
|
dilation=dilation, blur=self.blur, **kwargs)
|
||||||
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
|
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
|
||||||
self.inplanes = planes * block.expansion
|
self.inplanes = planes * block.expansion
|
||||||
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
|
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
|
||||||
@ -1114,3 +1141,26 @@ def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwarg
|
|||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a ResNet-18 model with blur anti-aliasing
|
||||||
|
"""
|
||||||
|
default_cfg = default_cfgs['resnetblur18']
|
||||||
|
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
"""Constructs a ResNet-50 model with blur anti-aliasing
|
||||||
|
"""
|
||||||
|
default_cfg = default_cfgs['resnetblur50']
|
||||||
|
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='max_strided', **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user