diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py
index d1d38fb3..4d4a51a7 100644
--- a/timm/layers/attention2d.py
+++ b/timm/layers/attention2d.py
@@ -134,11 +134,15 @@ class MultiQueryAttention2d(nn.Module):
         self.query = nn.Sequential()
         if self.has_query_strides:
             # FIXME dilation
-            self.query.add_module('down_pool', create_pool2d(
-                'avg',
-                kernel_size=self.query_strides,
-                padding=padding,
-            ))
+            if padding == 'same':
+                self.query.add_module('down_pool', create_pool2d(
+                    'avg',
+                    kernel_size=self.query_strides,
+                    padding='same',
+                ))
+            else:
+                # no padding needed: kernel == stride (even), so the pool tiles the input exactly
+                self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides))
         self.query.add_module('norm', norm_layer(dim))
         self.query.add_module('proj', create_conv2d(
             dim,
@@ -190,7 +194,7 @@ class MultiQueryAttention2d(nn.Module):
 
         self.output = nn.Sequential()
         if self.has_query_strides:
-            self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
+            self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
         self.output.add_module('proj', create_conv2d(
             self.value_dim * self.num_heads,
             dim_out,
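
For context on the two changes (this sketch is not part of the patch): `nn.Upsample`'s first positional argument is `size`, so the old call treated `self.query_strides` as a fixed output size rather than a scale factor, and `nn.AvgPool2d` defaults `stride` to `kernel_size`, so with even input dims the pool tiles the input exactly and needs no padding. A minimal standalone sketch, with shapes chosen only for illustration:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16)  # (B, C, H, W) feature map

# Buggy form: the first positional arg of nn.Upsample is `size`,
# so (2, 2) collapses every input to a fixed 2x2 output.
up_size = nn.Upsample((2, 2), mode='bilinear', align_corners=False)
print(up_size(x).shape)   # torch.Size([1, 8, 2, 2])

# Fixed form: (2, 2) passed as `scale_factor` doubles H and W instead.
up_scale = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=False)
print(up_scale(x).shape)  # torch.Size([1, 8, 32, 32])

# Down-pool branch: AvgPool2d defaults stride to kernel_size, so with
# kernel == stride == 2 and even input dims it tiles exactly, no padding needed.
pool = nn.AvgPool2d(kernel_size=2)
print(pool(x).shape)      # torch.Size([1, 8, 8, 8])
```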