diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index cddb68a0..9d2ad633 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -279,7 +279,7 @@ class SwinTransformerV2Block(nn.Module):
         else:
             attn_mask = None
 
-        self.register_buffer("attn_mask", attn_mask)
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
 
     def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
@@ -616,11 +616,12 @@ def checkpoint_filter_fn(state_dict, model):
     out_dict = {}
     import re
     for k, v in state_dict.items():
-        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
+        if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]):
             continue  # skip buffers that should not be persistent
         k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
         k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
+    return out_dict
 
 
 def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
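
For context on why the two hunks are paired: a buffer registered with `persistent=False` is excluded from `state_dict()`, but checkpoints written before this change still contain the `attn_mask` key, so `checkpoint_filter_fn` must drop it or a strict `load_state_dict` fails with an unexpected-key error. A minimal standalone sketch of that interaction (the `Block` module below is a hypothetical stand-in, not the timm class):

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for SwinTransformerV2Block's buffer handling."""

    def __init__(self, persistent: bool):
        super().__init__()
        # attn_mask is derived from the window/shift geometry, so it can be
        # recomputed on construction and need not be stored in checkpoints.
        self.register_buffer("attn_mask", torch.zeros(4, 4), persistent=persistent)


old = Block(persistent=True)   # behavior before this patch
new = Block(persistent=False)  # behavior after this patch

print("attn_mask" in old.state_dict())  # True  -> saved into checkpoints
print("attn_mask" in new.state_dict())  # False -> excluded from checkpoints

# An old checkpoint still carries the key; loading it strictly into the new
# module raises "Unexpected key(s) in state_dict", which is why
# checkpoint_filter_fn now skips 'attn_mask' keys the same way it already
# skipped 'relative_position_index' and 'relative_coords_table'.
ckpt = old.state_dict()
filtered = {k: v for k, v in ckpt.items() if "attn_mask" not in k}
new.load_state_dict(filtered)  # loads cleanly
```

Note that non-persistent buffers still move with the module on `.to(device)`; they simply never round-trip through checkpoints, which keeps checkpoint files smaller and avoids stale masks when the model is rebuilt with a different window or shift configuration.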