mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
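Fix MQA query_strides bugs (#2237): apply padding to the avg_pool2d query downsample only when padding is 'same' (otherwise use an unpadded AvgPool2d), and pass query_strides to Upsample as scale_factor instead of as the output size.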
This commit is contained in:
parent
474c9cf768
commit
2180800646
@ -134,11 +134,15 @@ class MultiQueryAttention2d(nn.Module):
|
|||||||
self.query = nn.Sequential()
|
self.query = nn.Sequential()
|
||||||
if self.has_query_strides:
|
if self.has_query_strides:
|
||||||
# FIXME dilation
|
# FIXME dilation
|
||||||
|
if padding == 'same':
|
||||||
self.query.add_module('down_pool', create_pool2d(
|
self.query.add_module('down_pool', create_pool2d(
|
||||||
'avg',
|
'avg',
|
||||||
kernel_size=self.query_strides,
|
kernel_size=self.query_strides,
|
||||||
padding=padding,
|
padding='same',
|
||||||
))
|
))
|
||||||
|
else:
|
||||||
|
# no pad if not 'same' as kern=stride=even
|
||||||
|
self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
|
||||||
self.query.add_module('norm', norm_layer(dim))
|
self.query.add_module('norm', norm_layer(dim))
|
||||||
self.query.add_module('proj', create_conv2d(
|
self.query.add_module('proj', create_conv2d(
|
||||||
dim,
|
dim,
|
||||||
@ -190,7 +194,7 @@ class MultiQueryAttention2d(nn.Module):
|
|||||||
|
|
||||||
self.output = nn.Sequential()
|
self.output = nn.Sequential()
|
||||||
if self.has_query_strides:
|
if self.has_query_strides:
|
||||||
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
|
self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
|
||||||
self.output.add_module('proj', create_conv2d(
|
self.output.add_module('proj', create_conv2d(
|
||||||
self.value_dim * self.num_heads,
|
self.value_dim * self.num_heads,
|
||||||
dim_out,
|
dim_out,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user