MQA query_strides bugs fix #2237. No padding for avg_pool2d if not 'same', use scale_factor for Upsample.

This commit is contained in:
Ross Wightman 2024-07-19 14:26:54 -07:00
parent 474c9cf768
commit 2180800646

View File

@ -134,11 +134,15 @@ class MultiQueryAttention2d(nn.Module):
self.query = nn.Sequential()
if self.has_query_strides:
# FIXME dilation
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding=padding,
))
if padding == 'same':
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding='same',
))
else:
# no pad if not 'same' as kern=stride=even
self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
self.query.add_module('norm', norm_layer(dim))
self.query.add_module('proj', create_conv2d(
dim,
@ -190,7 +194,7 @@ class MultiQueryAttention2d(nn.Module):
self.output = nn.Sequential()
if self.has_query_strides:
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('proj', create_conv2d(
self.value_dim * self.num_heads,
dim_out,