Merge pull request #1760 from MarcoForte/patch-1

skip SwinV2 attention mask buffers (onnx_export)
Ross Wightman 2023-04-11 14:37:02 -07:00 committed by GitHub
commit 5fa53c5f31
1 changed file with 3 additions and 2 deletions
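Per the commit message, the motivation is ONNX export of SwinV2 models. A minimal export sketch, assuming a stock timm variant name ('swinv2_tiny_window8_256' is just an example) and a placeholder output path:

import torch
import timm

# Build a SwinV2 model; any swinv2_* variant works the same way.
model = timm.create_model('swinv2_tiny_window8_256', pretrained=False).eval()

# Trace with a dummy input matching the model's 256x256 resolution.
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, dummy, 'swinv2_tiny_window8_256.onnx', opset_version=14)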

@@ -279,7 +279,7 @@ class SwinTransformerV2Block(nn.Module):
         else:
             attn_mask = None
-        self.register_buffer("attn_mask", attn_mask)
+        self.register_buffer("attn_mask", attn_mask, persistent=False)

     def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
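The persistent=False flag is standard torch.nn.Module.register_buffer behavior: the buffer still follows the module through .to()/.cuda() and appears in named_buffers(), but it is excluded from state_dict(), so it is neither written to checkpoints nor expected when loading them. A minimal sketch with a toy module (not the SwinV2 block itself):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: serialized in state_dict() and checkpoints.
        self.register_buffer("kept_mask", torch.zeros(4, 4))
        # Non-persistent buffer: tracked by the module but left out of state_dict().
        self.register_buffer("attn_mask", torch.zeros(4, 4), persistent=False)

print(ToyBlock().state_dict().keys())
# odict_keys(['kept_mask'])  -- 'attn_mask' is not serialized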
@@ -616,11 +616,12 @@ def checkpoint_filter_fn(state_dict, model):
     out_dict = {}
     import re
     for k, v in state_dict.items():
-        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
+        if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]):
             continue  # skip buffers that should not be persistent
         k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
         k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
     return out_dict
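Checkpoints saved before this change may still contain attn_mask entries; checkpoint_filter_fn drops them (along with the other non-persistent buffers) while remapping keys. A hedged usage sketch, assuming the filter is importable from timm.models.swin_transformer_v2 and using a placeholder checkpoint path; normal pretrained loading via create_model(..., pretrained=True) does not require any of this by hand:

import torch
import timm
from timm.models.swin_transformer_v2 import checkpoint_filter_fn

model = timm.create_model('swinv2_tiny_window8_256', pretrained=False)

# 'old_swinv2_checkpoint.pth' is a placeholder for a checkpoint written before this change.
state_dict = torch.load('old_swinv2_checkpoint.pth', map_location='cpu')
state_dict = checkpoint_filter_fn(state_dict, model)

# Stale buffer keys are already filtered out; strict=False tolerates any remaining mismatch.
missing, unexpected = model.load_state_dict(state_dict, strict=False)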