[Fix] Fix the output position of Swin-Transformer. (#947)
* [Fix] Fix the output position of Swin-Transformer.
* Rename `downsample` argument to `do_downsample`.
pull/881/head
parent
6ec38fe742
commit
b5bb86a357
|
@ -183,11 +183,11 @@ class SwinBlockSequence(BaseModule):
|
|||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, in_shape):
|
||||
def forward(self, x, in_shape, do_downsample=True):
|
||||
for block in self.blocks:
|
||||
x = block(x, in_shape)
|
||||
|
||||
if self.downsample:
|
||||
if self.downsample is not None and do_downsample:
|
||||
x, out_shape = self.downsample(x, in_shape)
|
||||
else:
|
||||
out_shape = in_shape
|
||||
|
@ -232,6 +232,8 @@ class SwinTransformer(BaseBackbone):
|
|||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
drop_rate (float): Dropout rate after embedding. Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
|
||||
out_after_downsample (bool): Whether to output the feature map of a
|
||||
stage after the following downsample layer. Defaults to False.
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults to False.
|
||||
interpolate_mode (str): Select the interpolate mode for absolute
|
||||
|
@ -301,6 +303,7 @@ class SwinTransformer(BaseBackbone):
|
|||
drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
out_indices=(3, ),
|
||||
out_after_downsample=False,
|
||||
use_abs_pos_embed=False,
|
||||
interpolate_mode='bicubic',
|
||||
with_cp=False,
|
||||
|
@ -329,6 +332,7 @@ class SwinTransformer(BaseBackbone):
|
|||
self.num_heads = self.arch_settings['num_heads']
|
||||
self.num_layers = len(self.depths)
|
||||
self.out_indices = out_indices
|
||||
self.out_after_downsample = out_after_downsample
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.frozen_stages = frozen_stages
|
||||
|
@ -392,9 +396,15 @@ class SwinTransformer(BaseBackbone):
|
|||
dpr = dpr[depth:]
|
||||
embed_dims.append(stage.out_channels)
|
||||
|
||||
if self.out_after_downsample:
|
||||
self.num_features = embed_dims[1:]
|
||||
else:
|
||||
self.num_features = embed_dims[:-1]
|
||||
|
||||
for i in out_indices:
|
||||
if norm_cfg is not None:
|
||||
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
|
||||
norm_layer = build_norm_layer(norm_cfg,
|
||||
self.num_features[i])[1]
|
||||
else:
|
||||
norm_layer = nn.Identity()
|
||||
|
||||
|
@ -421,14 +431,17 @@ class SwinTransformer(BaseBackbone):
|
|||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
x, hw_shape = stage(
|
||||
x, hw_shape, do_downsample=self.out_after_downsample)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape,
|
||||
stage.out_channels).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
if stage.downsample is not None and not self.out_after_downsample:
|
||||
x, hw_shape = stage.downsample(x, hw_shape)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ class TestSwinTransformer(TestCase):
|
|||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 4)
|
||||
for stride, out in zip([2, 4, 8, 8], outs):
|
||||
for stride, out in zip([1, 2, 4, 8], outs):
|
||||
self.assertEqual(out.shape,
|
||||
(1, 128 * stride, 56 // stride, 56 // stride))
|
||||
|
||||
|
|
Loading…
Reference in New Issue