diff --git a/mmcls/models/backbones/swin_transformer.py b/mmcls/models/backbones/swin_transformer.py index 0ab82f19..962d41d6 100644 --- a/mmcls/models/backbones/swin_transformer.py +++ b/mmcls/models/backbones/swin_transformer.py @@ -183,11 +183,11 @@ class SwinBlockSequence(BaseModule): else: self.downsample = None - def forward(self, x, in_shape): + def forward(self, x, in_shape, do_downsample=True): for block in self.blocks: x = block(x, in_shape) - if self.downsample: + if self.downsample is not None and do_downsample: x, out_shape = self.downsample(x, in_shape) else: out_shape = in_shape @@ -232,6 +232,8 @@ class SwinTransformer(BaseBackbone): window_size (int): The height and width of the window. Defaults to 7. drop_rate (float): Dropout rate after embedding. Defaults to 0. drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + out_after_downsample (bool): Whether to output the feature map of a + stage after the following downsample layer. Defaults to False. use_abs_pos_embed (bool): If True, add absolute position embedding to the patch embedding. Defaults to False. interpolate_mode (str): Select the interpolate mode for absolute @@ -301,6 +303,7 @@ class SwinTransformer(BaseBackbone): drop_rate=0., drop_path_rate=0.1, out_indices=(3, ), + out_after_downsample=False, use_abs_pos_embed=False, interpolate_mode='bicubic', with_cp=False, @@ -329,6 +332,7 @@ class SwinTransformer(BaseBackbone): self.num_heads = self.arch_settings['num_heads'] self.num_layers = len(self.depths) self.out_indices = out_indices + self.out_after_downsample = out_after_downsample self.use_abs_pos_embed = use_abs_pos_embed self.interpolate_mode = interpolate_mode self.frozen_stages = frozen_stages @@ -392,9 +396,15 @@ class SwinTransformer(BaseBackbone): dpr = dpr[depth:] embed_dims.append(stage.out_channels) + if self.out_after_downsample: + self.num_features = embed_dims[1:] + else: + self.num_features = embed_dims[:-1] + for i in out_indices: if norm_cfg is not None: - norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1] + norm_layer = build_norm_layer(norm_cfg, + self.num_features[i])[1] else: norm_layer = nn.Identity() @@ -421,14 +431,17 @@ class SwinTransformer(BaseBackbone): outs = [] for i, stage in enumerate(self.stages): - x, hw_shape = stage(x, hw_shape) + x, hw_shape = stage( + x, hw_shape, do_downsample=self.out_after_downsample) if i in self.out_indices: norm_layer = getattr(self, f'norm{i}') out = norm_layer(x) out = out.view(-1, *hw_shape, - stage.out_channels).permute(0, 3, 1, - 2).contiguous() + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() outs.append(out) + if stage.downsample is not None and not self.out_after_downsample: + x, hw_shape = stage.downsample(x, hw_shape) return tuple(outs) diff --git a/tests/test_models/test_backbones/test_swin_transformer.py b/tests/test_models/test_backbones/test_swin_transformer.py index 90d7db71..33947304 100644 --- a/tests/test_models/test_backbones/test_swin_transformer.py +++ b/tests/test_models/test_backbones/test_swin_transformer.py @@ -167,7 +167,7 @@ class TestSwinTransformer(TestCase): outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 4) - for stride, out in zip([2, 4, 8, 8], outs): + for stride, out in zip([1, 2, 4, 8], outs): self.assertEqual(out.shape, (1, 128 * stride, 56 // stride, 56 // stride))