Implement absolute + window pos embed for Hiera; resizable, but needs new weights

Ross Wightman 2024-07-18 21:43:37 -07:00
parent 392b78aee7
commit 3a8a965891


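The two new constructor arguments ride through `timm.create_model` kwargs like any other model arg. A minimal usage sketch (hypothetical: `hiera_small_224` is just an example variant, and `pretrained=False` because, per the commit message above, the new embedding layout needs newly trained weights):

import timm

# Enable the absolute + window position embedding added in this commit.
# abs_pos_size is the stored grid of the global table; it is resized on
# the fly in _pos_embed, so other input sizes remain usable.
model = timm.create_model(
    'hiera_small_224',
    pretrained=False,       # no pretrained weights exist for this layout yet
    abs_win_pos_embed=True,
    abs_pos_size=(14, 14),
)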
@@ -476,8 +476,11 @@ class Hiera(nn.Module):
             drop_path_rate: float = 0.0,
             norm_layer: Union[str, nn.Module] = "LayerNorm",
             drop_rate: float = 0.0,
+            patch_drop_rate: float = 0.0,
             head_init_scale: float = 0.001,
             sep_pos_embed: bool = False,
+            abs_win_pos_embed: bool = False,
+            abs_pos_size: Tuple[int, int] = (14, 14),
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -494,6 +497,7 @@ class Hiera(nn.Module):
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
         self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
         self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        self.patch_drop_rate = patch_drop_rate
 
         self.patch_embed = PatchEmbed(
             in_chans,
@@ -504,8 +508,12 @@ class Hiera(nn.Module):
             #reshape=False,  # leave spatial / temporal dims in output
         )
 
+        self.pos_embed: Optional[nn.Parameter] = None
+        self.pos_embed_abs: Optional[nn.Parameter] = None
+        self.pos_embed_win: Optional[nn.Parameter] = None
+        self.pos_embed_spatial: Optional[nn.Parameter] = None
+        self.pos_embed_temporal: Optional[nn.Parameter] = None
         if sep_pos_embed:
-            self.pos_embed = None
             self.pos_embed_spatial = nn.Parameter(
                 torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim)
             )
@@ -513,9 +521,12 @@ class Hiera(nn.Module):
                 torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
             )
         else:
-            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
-            self.pos_embed_spatial = None
-            self.pos_embed_temporal = None
+            if abs_win_pos_embed:
+                # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
+                self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
+                self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
+            else:
+                self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
 
         # Setup roll and reroll modules
         self.unroll = Unroll(
@@ -584,7 +595,11 @@ class Hiera(nn.Module):
             nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
             nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
         else:
-            nn.init.trunc_normal_(self.pos_embed, std=0.02)
+            if self.pos_embed is not None:
+                nn.init.trunc_normal_(self.pos_embed, std=0.02)
+            elif self.pos_embed_abs is not None:
+                nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
+                nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
         self.apply(partial(self._init_weights))
         if isinstance(self.head.fc, nn.Linear):
             self.head.fc.weight.data.mul_(head_init_scale)
@@ -603,13 +618,15 @@ class Hiera(nn.Module):
     def no_weight_decay(self):
         if self.pos_embed is not None:
             return ["pos_embed"]
+        elif self.pos_embed_abs is not None:
+            return ['pos_embed_abs', 'pos_embed_win']
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]
 
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict:
         return dict(
-            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',  # stem and embed
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
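The names returned by `no_weight_decay` are consumed by optimizer setup code to keep weight decay off the embedding tables. A simplified sketch of that consumption (the real grouping logic lives in `timm.optim`; this standalone helper is only illustrative):

import torch

def make_param_groups(model: torch.nn.Module, weight_decay: float = 0.05):
    # Exclude 1-d params (norms, biases) and anything the model flags
    # via no_weight_decay() -- e.g. pos_embed_abs / pos_embed_win here.
    skip = set(getattr(model, 'no_weight_decay', lambda: [])())
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or any(name.startswith(s) for s in skip):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]

# optimizer = torch.optim.AdamW(make_param_groups(model), lr=1e-4)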
@@ -652,6 +669,18 @@ class Hiera(nn.Module):
     def _pos_embed(self, x) -> torch.Tensor:
         if self.pos_embed is not None:
             pos_embed = self.pos_embed
+        elif self.pos_embed_abs is not None:
+            # absolute win position embedding, from
+            # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
+            pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
+            pos_embed_abs = F.interpolate(
+                self.pos_embed_abs,
+                size=pos_embed_win.shape[-2:],
+                mode='bicubic',
+                antialias=True,
+            )
+            pos_embed = pos_embed_abs + pos_embed_win
+            pos_embed = pos_embed.flatten(2).transpose(1, 2)
         else:
             pos_embed = (
                 self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
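To make the new branch concrete, here is a standalone sketch of the same computation with illustrative sizes (embed dim 96, the default 8x8-token mask units, and the 7x7 mask-unit grid a 224px input produces via the 4-pixel patch stride). The key point from the referenced paper is that only the global table is ever interpolated; the window component is tiled, never resized:

import torch
import torch.nn.functional as F

embed_dim = 96
mask_unit = (8, 8)     # tokens per mask unit (Hiera's default mask_unit_size)
mask_units = (7, 7)    # mask units across the feature map -> 56x56 tokens

pos_embed_abs = torch.randn(1, embed_dim, 14, 14)      # learned global table, NCHW
pos_embed_win = torch.randn(1, embed_dim, *mask_unit)  # learned per-window table, NCHW

# tile the window embed so every mask unit repeats the same local pattern
win = pos_embed_win.tile(mask_units)                   # -> [1, 96, 56, 56]

# interpolate only the global table to the current token grid
abs_part = F.interpolate(pos_embed_abs, size=win.shape[-2:], mode='bicubic', antialias=True)

# add, then reshape NCHW -> NLC to match the token sequence
pos_embed = (abs_part + win).flatten(2).transpose(1, 2)  # -> [1, 3136, 96]

Keeping the parameters in NCHW, as the constructor comment notes, is what lets both the tile and the interpolate happen directly, with a single flatten/transpose at the end.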
@@ -735,7 +764,6 @@ class Hiera(nn.Module):
         self.head.reset(0, other=True)
         return take_indices
 
-
     def forward_features(
         self,
         x: torch.Tensor,
@@ -746,6 +774,11 @@ class Hiera(nn.Module):
         mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
         Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
         """
+        if self.training and self.patch_drop_rate > 0:
+            # using mask for something like 'patch dropout' via mask-units in supervised train / fine-tune
+            assert mask is None
+            mask = self.get_random_mask(x, mask_ratio=self.patch_drop_rate)
+
         if mask is not None:
             patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
         else:
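`get_random_mask` is the method the MAE-style Hiera code already uses for pretraining; reused here, it acts as a coarse patch dropout over whole mask units. A standalone sketch of an equivalent mask generator, honoring the docstring contract above (1 = keep, 0 = remove, identical keep count per batch row):

import torch

def random_keep_mask(batch_size: int, num_windows: int, mask_ratio: float) -> torch.Tensor:
    # Boolean [B, num_windows] mask over mask units: True = keep.
    # Ranking uniform noise keeps exactly num_keep units in every row,
    # so mask.sum(dim=-1) is constant across the batch as required.
    num_keep = int(num_windows * (1 - mask_ratio))
    noise = torch.rand(batch_size, num_windows)
    keep_idx = noise.argsort(dim=1)[:, :num_keep]
    mask = torch.zeros(batch_size, num_windows, dtype=torch.bool)
    mask.scatter_(1, keep_idx, True)
    return mask

mask = random_keep_mask(4, 49, mask_ratio=0.3)     # e.g. a 7x7 mask-unit grid
assert mask.sum(dim=-1).eq(mask.sum(dim=-1)[0]).all()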
@@ -801,6 +834,7 @@ def _cfg(url='', **kwargs):
         **kwargs
     }
 
+
 default_cfgs = generate_default_cfgs({
     "hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
         hf_hub_id='timm/',
@@ -863,21 +897,20 @@ default_cfgs = generate_default_cfgs({
     ),
 })
 
 def checkpoint_filter_fn(state_dict, model=None):
     state_dict = state_dict.get('model_state', state_dict)
     output = {}
     for k, v in state_dict.items():
-        if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
-            # # To resize pos embedding when using model at different size from pretrained weights
-            # from timm.layers import resample_abs_pos_embed
-            # v = resample_abs_pos_embed(
-            #     v,
-            #     new_size=(64, 64),
-            #     num_prefix_tokens=0,
-            #     verbose=True,
-            # )
-            #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
-            pass
+        # if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
+        #     # To resize pos embedding when using model at different size from pretrained weights
+        #     from timm.layers import resample_abs_pos_embed
+        #     v = resample_abs_pos_embed(
+        #         v,
+        #         new_size=(64, 64),
+        #         num_prefix_tokens=0,
+        #         verbose=True,
+        #     )
+
         if 'head.projection.' in k:
             k = k.replace('head.projection.', 'head.fc.')
         if k.startswith('encoder_norm.'):
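The commented-out block references `timm.layers.resample_abs_pos_embed`, which does exist and could be re-enabled along these lines (sizes are illustrative; `num_prefix_tokens=0` because Hiera carries no class token):

import torch
from timm.layers import resample_abs_pos_embed

# e.g. resize a [1, 56*56, C] pos_embed from a 224px checkpoint to a
# model expecting a 64x64 token grid (illustrative sizes)
v = torch.randn(1, 56 * 56, 96)
v = resample_abs_pos_embed(
    v,
    new_size=(64, 64),
    old_size=(56, 56),    # may be omitted; a square grid is then inferred
    num_prefix_tokens=0,  # Hiera has no prefix/class tokens
    verbose=True,
)
print(v.shape)  # torch.Size([1, 4096, 96])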