Merge pull request #2238 from huggingface/fix_mnv4_query_strides
Fix mnv4 query strides

commit a1996ec0f4
@@ -134,11 +134,15 @@ class MultiQueryAttention2d(nn.Module):
         self.query = nn.Sequential()
         if self.has_query_strides:
             # FIXME dilation
-            self.query.add_module('down_pool', create_pool2d(
-                'avg',
-                kernel_size=self.query_strides,
-                padding=padding,
-            ))
+            if padding == 'same':
+                self.query.add_module('down_pool', create_pool2d(
+                    'avg',
+                    kernel_size=self.query_strides,
+                    padding='same',
+                ))
+            else:
+                # no pad if not 'same' as kern=stride=even
+                self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
         self.query.add_module('norm', norm_layer(dim))
         self.query.add_module('proj', create_conv2d(
             dim,
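The new else-branch leans on the fact noted in its comment: when kernel_size equals stride and the spatial dims divide evenly, the pooling windows tile the input exactly, so no padding is needed at all. A minimal sketch of that case (the stride of 2 and shapes below are illustrative, not from the PR):

import torch
import torch.nn as nn

# kernel_size == stride == 2 on an even-sized input: windows tile exactly,
# so unpadded average pooling halves each spatial dim with no border effects.
x = torch.randn(1, 8, 16, 16)
pool = nn.AvgPool2d(kernel_size=2)  # stride defaults to kernel_size
print(pool(x).shape)  # torch.Size([1, 8, 8, 8])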
@@ -190,7 +194,7 @@ class MultiQueryAttention2d(nn.Module):

         self.output = nn.Sequential()
         if self.has_query_strides:
-            self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
+            self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
         self.output.add_module('proj', create_conv2d(
             self.value_dim * self.num_heads,
             dim_out,
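The one-line change above matters because the first positional argument of nn.Upsample is size, not scale_factor: passing the query strides positionally pinned the output to a fixed spatial size instead of scaling by that factor. A quick sketch of the difference (input shape is illustrative):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 8, 8)

# Positional arg is `size`: output is fixed at 2x2 regardless of input size.
wrong = nn.Upsample(2, mode='bilinear', align_corners=False)
print(wrong(x).shape)  # torch.Size([1, 3, 2, 2])

# Keyword `scale_factor`: output is 2x the input, as intended.
right = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
print(right(x).shape)  # torch.Size([1, 3, 16, 16])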
@@ -3,14 +3,19 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 
+from .helpers import to_2tuple
+
 
 # Calculate symmetric padding for a convolution
-def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
+def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]:
+    if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
+        kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
+        return [get_padding(*a) for a in zip(kernel_size, stride, dilation)]
     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
     return padding
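With this hunk, get_padding accepts per-dimension tuples or lists: the args are normalized with to_2tuple and the function recurses once per dimension, returning a list instead of a scalar. A usage sketch (assuming a timm build that includes this change; the import path and values are illustrative):

from timm.layers.padding import get_padding

print(get_padding(3))       # 1: scalar args behave exactly as before
print(get_padding((3, 5)))  # [1, 2]: per-dimension padding for kh=3, kw=5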
@@ -25,6 +30,9 @@ def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):


 # Can SAME padding for given args be done statically?
 def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
+    if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
+        kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
+        return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)])
     return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
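is_static_pad gets the same tuple handling, reducing the per-dimension checks with all(): SAME padding is static only if every dimension passes. A small sketch (same assumption about the timm version):

from timm.layers.padding import is_static_pad

print(is_static_pad(3))                      # True: stride 1, even total pad
print(is_static_pad(3, stride=2))            # False: stride > 1 needs dynamic SAME pad
print(is_static_pad((3, 3), stride=(1, 2)))  # False: all() fails on the strided dim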