[Fix] Fix the output position of Swin-Transformer. (#947)

* [Fix] Fix the output position of Swin-Transformer.

* Rename `downsample` argument to `do_downsample`.
pull/881/head
Ma Zerun 2022-08-03 19:32:29 +08:00 committed by GitHub
parent 6ec38fe742
commit b5bb86a357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 7 deletions

View File

@ -183,11 +183,11 @@ class SwinBlockSequence(BaseModule):
else:
self.downsample = None
def forward(self, x, in_shape):
def forward(self, x, in_shape, do_downsample=True):
for block in self.blocks:
x = block(x, in_shape)
if self.downsample:
if self.downsample is not None and do_downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
@ -232,6 +232,8 @@ class SwinTransformer(BaseBackbone):
window_size (int): The height and width of the window. Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
out_after_downsample (bool): Whether to output the feature map of a
stage after the following downsample layer. Defaults to False.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
@ -301,6 +303,7 @@ class SwinTransformer(BaseBackbone):
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
out_after_downsample=False,
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
@ -329,6 +332,7 @@ class SwinTransformer(BaseBackbone):
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.out_after_downsample = out_after_downsample
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
@ -392,9 +396,15 @@ class SwinTransformer(BaseBackbone):
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
if self.out_after_downsample:
self.num_features = embed_dims[1:]
else:
self.num_features = embed_dims[:-1]
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
norm_layer = build_norm_layer(norm_cfg,
self.num_features[i])[1]
else:
norm_layer = nn.Identity()
@ -421,14 +431,17 @@ class SwinTransformer(BaseBackbone):
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
x, hw_shape = stage(
x, hw_shape, do_downsample=self.out_after_downsample)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if stage.downsample is not None and not self.out_after_downsample:
x, hw_shape = stage.downsample(x, hw_shape)
return tuple(outs)

View File

@ -167,7 +167,7 @@ class TestSwinTransformer(TestCase):
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 4)
for stride, out in zip([2, 4, 8, 8], outs):
for stride, out in zip([1, 2, 4, 8], outs):
self.assertEqual(out.shape,
(1, 128 * stride, 56 // stride, 56 // stride))