mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
MQA query_strides bugs fix #2237. No padding for avg_pool2d if not 'same', use scale_factor for Upsample.
This commit is contained in:
parent
474c9cf768
commit
2180800646
@ -134,11 +134,15 @@ class MultiQueryAttention2d(nn.Module):
|
||||
self.query = nn.Sequential()
|
||||
if self.has_query_strides:
|
||||
# FIXME dilation
|
||||
self.query.add_module('down_pool', create_pool2d(
|
||||
'avg',
|
||||
kernel_size=self.query_strides,
|
||||
padding=padding,
|
||||
))
|
||||
if padding == 'same':
|
||||
self.query.add_module('down_pool', create_pool2d(
|
||||
'avg',
|
||||
kernel_size=self.query_strides,
|
||||
padding='same',
|
||||
))
|
||||
else:
|
||||
# no pad if not 'same' as kern=stride=even
|
||||
self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
|
||||
self.query.add_module('norm', norm_layer(dim))
|
||||
self.query.add_module('proj', create_conv2d(
|
||||
dim,
|
||||
@ -190,7 +194,7 @@ class MultiQueryAttention2d(nn.Module):
|
||||
|
||||
self.output = nn.Sequential()
|
||||
if self.has_query_strides:
|
||||
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
|
||||
self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
|
||||
self.output.add_module('proj', create_conv2d(
|
||||
self.value_dim * self.num_heads,
|
||||
dim_out,
|
||||
|
Loading…
x
Reference in New Issue
Block a user