111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
""" Conv2d w/ Same Padding
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Tuple, Optional
|
|
|
|
from .config import is_exportable, is_scriptable
|
|
from .padding import pad_same, pad_same_arg, get_padding_value
|
|
|
|
|
|
# Opt-in switch read by create_conv2d_pad: when True and is_exportable() also
# returns true, dynamically padded convs are built as Conv2dSameExport instead
# of Conv2dSame. Off by default.
_USE_EXPORT_CONV = False
|
|
|
|
|
|
def conv2d_same(
        x,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        dilation: Tuple[int, int] = (1, 1),
        groups: int = 1,
):
    """Functional 2D convolution with 'SAME' style input padding.

    The input is first padded via ``pad_same`` using the kernel's spatial
    size (last two dims of ``weight``), ``stride`` and ``dilation``; the
    convolution itself is then run with an explicit padding of ``(0, 0)``.

    Args:
        x: Input passed to ``pad_same`` and then to ``F.conv2d``.
        weight: Convolution kernel; its last two dims give the kernel size.
        bias: Optional bias forwarded to ``F.conv2d``.
        stride: Convolution stride, also used to compute the padding.
        padding: Accepted for signature parity with a regular conv call but
            not used; padding is always derived by ``pad_same``.
        dilation: Convolution dilation, also used to compute the padding.
        groups: Number of groups forwarded to ``F.conv2d``.

    Returns:
        The result of ``F.conv2d`` on the padded input.
    """
    kernel_hw = weight.shape[-2:]
    padded = pad_same(x, kernel_hw, stride, dilation)
    # Padding was already applied above, so the conv gets none of its own.
    return F.conv2d(
        padded, weight, bias,
        stride=stride, padding=(0, 0), dilation=dilation, groups=groups,
    )
|
|
|
|
|
|
class Conv2dSame(nn.Conv2d):
    """ 2D convolution layer with Tensorflow like 'SAME' padding.

    Behaves like ``nn.Conv2d`` except that the input is padded inside
    ``forward`` (through ``conv2d_same``) rather than by the conv op.
    The ``padding`` constructor argument is accepted but ignored; the
    underlying ``nn.Conv2d`` is always created with padding 0.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
    ):
        # `padding` is intentionally not forwarded: the base layer is built
        # with zero padding and the padding is applied in forward().
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x`` via ``conv2d_same`` and convolve with this layer's parameters."""
        return conv2d_same(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
|
|
|
|
|
|
class Conv2dSameExport(nn.Conv2d):
    """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The padding is materialised as an ``nn.ZeroPad2d`` module computed from
    the input's spatial size and cached on the instance. The cached module is
    rebuilt whenever an input with a different spatial size is seen, so the
    padding never goes stale.

    NOTE: This does not currently work with torch.jit.script
    """

    # pylint: disable=unused-argument
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
    ):
        # `padding` is accepted for signature compatibility but ignored; the
        # base conv is always built with zero padding.
        super(Conv2dSameExport, self).__init__(
            in_channels, out_channels, kernel_size,
            stride, 0, dilation, groups, bias,
        )
        # Lazily created ZeroPad2d and the input (H, W) it was computed for.
        self.pad = None
        self.pad_input_size = (0, 0)

    def forward(self, x):
        """Zero-pad ``x`` with the cached pad module, then run the convolution.

        The pad module is (re)built on first use and whenever the last two
        dims of ``x`` differ from the size the cached module was built for.
        """
        input_size = x.size()[-2:]
        # Previously the pad was computed only once, so a later input with a
        # different spatial size silently reused padding meant for the first.
        if self.pad is None or self.pad_input_size != input_size:
            pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = input_size

        x = self.pad(x)
        return F.conv2d(
            x, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )
|
|
|
|
|
|
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Build a 2D conv layer, choosing the class from the resolved padding.

    ``padding`` is popped from ``kwargs`` (default ``''``) and resolved with
    ``get_padding_value``; ``bias`` defaults to ``False`` when not given.

    Returns:
        * ``nn.Conv2d`` with the resolved padding when it is not dynamic.
        * ``Conv2dSameExport`` when padding is dynamic, ``_USE_EXPORT_CONV``
          is set and ``is_exportable()`` is true.
        * ``Conv2dSame`` for every other dynamic-padding case.

    Raises:
        AssertionError: if the export variant is selected while
            ``is_scriptable()`` is true.
    """
    requested_padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(requested_padding, kernel_size, **kwargs)

    if not is_dynamic:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

    if not (_USE_EXPORT_CONV and is_exportable()):
        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)

    # older PyTorch ver needed this to export same padding reasonably
    assert not is_scriptable()  # Conv2DSameExport does not work with jit
    return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
|
|
|
|
|