Merge pull request #2238 from huggingface/fix_mnv4_query_strides

Fix mnv4 query strides
tiny_test_models^2
Ross Wightman 2024-07-19 16:32:08 -07:00 committed by GitHub
commit a1996ec0f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 8 deletions

View File

@ -134,11 +134,15 @@ class MultiQueryAttention2d(nn.Module):
self.query = nn.Sequential()
if self.has_query_strides:
# FIXME dilation
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding=padding,
))
if padding == 'same':
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding='same',
))
else:
# no pad if not 'same' as kern=stride=even
self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
self.query.add_module('norm', norm_layer(dim))
self.query.add_module('proj', create_conv2d(
dim,
@ -190,7 +194,7 @@ class MultiQueryAttention2d(nn.Module):
self.output = nn.Sequential()
if self.has_query_strides:
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('proj', create_conv2d(
self.value_dim * self.num_heads,
dim_out,

View File

@ -3,14 +3,19 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import List, Tuple
from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from .helpers import to_2tuple
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]:
    """Return the symmetric padding for a convolution.

    Scalars yield a single int; if any arg is a tuple/list, all args are
    expanded to 2-tuples and a per-dimension list of paddings is returned.
    Extra keyword args are accepted and ignored for call-site compatibility.
    """
    has_seq = any(isinstance(v, (tuple, list)) for v in (kernel_size, stride, dilation))
    if has_seq:
        # Normalize every arg to a 2-tuple, then recurse per spatial dim.
        dims = zip(to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation))
        return [get_padding(k, s, d) for k, s, d in dims]
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2
@ -25,6 +30,9 @@ def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """Return True if SAME padding for the given args can be applied statically.

    Static SAME padding is only possible at stride 1 with an even effective
    pad total. Tuple/list args are expanded to 2-tuples and every spatial
    dimension must individually qualify. Extra keyword args are ignored.
    """
    if any(isinstance(v, (tuple, list)) for v in (kernel_size, stride, dilation)):
        # Recurse per spatial dimension; all dims must allow static padding.
        dims = zip(to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation))
        return all(is_static_pad(k, s, d) for k, s, d in dims)
    if stride != 1:
        return False
    return (dilation * (kernel_size - 1)) % 2 == 0