mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update blurpool.py
clean up code for PR
This commit is contained in:
parent
3a287a6e76
commit
ce3d82b58b
@ -1,7 +1,7 @@
|
|||||||
'''independent attempt to implement
|
'''
|
||||||
|
BlurPool layer inspired by
|
||||||
MaxBlurPool2d in a more general fashion(separate maxpooling from BlurPool)
|
Kornia's Max_BlurPool2d
|
||||||
which was again inspired by
|
and
|
||||||
Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
|
||||||
|
|
||||||
'''
|
'''
|
||||||
@ -17,8 +17,7 @@ class BlurPool2d(nn.Module):
|
|||||||
Corresponds to the Downsample class, which does blurring and subsampling
|
Corresponds to the Downsample class, which does blurring and subsampling
|
||||||
Args:
|
Args:
|
||||||
channels = Number of input channels
|
channels = Number of input channels
|
||||||
blur_filter_size (int): filter size for blurring. currently supports either 3 or 5 (most common)
|
blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5.
|
||||||
defaults to 3.
|
|
||||||
stride (int): downsampling filter stride
|
stride (int): downsampling filter stride
|
||||||
Shape:
|
Shape:
|
||||||
Returns:
|
Returns:
|
||||||
@ -35,16 +34,16 @@ class BlurPool2d(nn.Module):
|
|||||||
|
|
||||||
if blur_filter_size == 3:
|
if blur_filter_size == 3:
|
||||||
pad_size = [1] * 4
|
pad_size = [1] * 4
|
||||||
blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial kernel b2
|
blur_matrix = torch.Tensor([[1., 2., 1]]) / 4 # binomial filter b2
|
||||||
else:
|
else:
|
||||||
pad_size = [2] * 4
|
pad_size = [2] * 4
|
||||||
blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter kernel b4
|
blur_matrix = torch.Tensor([[1., 4., 6., 4., 1.]]) / 16 # binomial filter b4
|
||||||
|
|
||||||
self.padding = nn.ReflectionPad2d(pad_size)
|
self.padding = nn.ReflectionPad2d(pad_size)
|
||||||
blur_filter = blur_matrix * blur_matrix.T
|
blur_filter = blur_matrix * blur_matrix.T
|
||||||
self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
||||||
|
|
||||||
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||||
if not torch.is_tensor(input_tensor):
|
if not torch.is_tensor(input_tensor):
|
||||||
raise TypeError("Input input type is not a torch.Tensor. Got {}"
|
raise TypeError("Input input type is not a torch.Tensor. Got {}"
|
||||||
.format(type(input_tensor)))
|
.format(type(input_tensor)))
|
||||||
@ -53,16 +52,3 @@ class BlurPool2d(nn.Module):
|
|||||||
.format(input_tensor.shape))
|
.format(input_tensor.shape))
|
||||||
# apply blur_filter on input
|
# apply blur_filter on input
|
||||||
return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1])
|
return F.conv2d(self.padding(input_tensor), self.blur_filter, stride=self.stride, groups=input_tensor.shape[1])
|
||||||
|
|
||||||
|
|
||||||
######################
|
|
||||||
# functional interface
|
|
||||||
######################
|
|
||||||
|
|
||||||
|
|
||||||
'''def blur_pool2d() -> torch.Tensor:
|
|
||||||
r"""Creates a module that computes pools and blurs and downsample a given
|
|
||||||
feature map.
|
|
||||||
See :class:`~kornia.contrib.MaxBlurPool2d` for details.
|
|
||||||
"""
|
|
||||||
return BlurPool2d(kernel_size, ceil_mode)(input)'''
|
|
Loading…
x
Reference in New Issue
Block a user