From 3a8a965891318552b8fe4fb399e3dafad731c91d Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 18 Jul 2024 21:43:37 -0700
Subject: [PATCH] Implement absolute+window pos embed for hiera, resizable but
 needs new weights

---
 timm/models/hiera.py | 69 ++++++++++++++++++++++++++++++++------------
 1 file changed, 51 insertions(+), 18 deletions(-)

diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index e06d3545..f3cece6e 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -476,8 +476,11 @@ class Hiera(nn.Module):
             drop_path_rate: float = 0.0,
             norm_layer: Union[str, nn.Module] = "LayerNorm",
             drop_rate: float = 0.0,
+            patch_drop_rate: float = 0.0,
             head_init_scale: float = 0.001,
             sep_pos_embed: bool = False,
+            abs_win_pos_embed: bool = False,
+            abs_pos_size: Tuple[int, int] = (14, 14),
     ):
         super().__init__()
         self.num_classes = num_classes
@@ -494,6 +497,7 @@ class Hiera(nn.Module):
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
         self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
         self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        self.patch_drop_rate = patch_drop_rate
 
         self.patch_embed = PatchEmbed(
             in_chans,
@@ -504,8 +508,12 @@ class Hiera(nn.Module):
             #reshape=False,  # leave spatial / temporal dims in output
         )
 
+        self.pos_embed: Optional[nn.Parameter] = None
+        self.pos_embed_abs: Optional[nn.Parameter] = None
+        self.pos_embed_win: Optional[nn.Parameter] = None
+        self.pos_embed_spatial: Optional[nn.Parameter] = None
+        self.pos_embed_temporal: Optional[nn.Parameter] = None
         if sep_pos_embed:
-            self.pos_embed = None
             self.pos_embed_spatial = nn.Parameter(
                 torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim)
             )
@@ -513,9 +521,12 @@ class Hiera(nn.Module):
                 torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
             )
         else:
-            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
-            self.pos_embed_spatial = None
-            self.pos_embed_temporal = None
+            if abs_win_pos_embed:
+                # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
+                self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
+                self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
+            else:
+                self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
 
         # Setup roll and reroll modules
         self.unroll = Unroll(
@@ -584,7 +595,11 @@ class Hiera(nn.Module):
             nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
             nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
         else:
-            nn.init.trunc_normal_(self.pos_embed, std=0.02)
+            if self.pos_embed is not None:
+                nn.init.trunc_normal_(self.pos_embed, std=0.02)
+            elif self.pos_embed_abs is not None:
+                nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
+                nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
         self.apply(partial(self._init_weights))
         if isinstance(self.head.fc, nn.Linear):
             self.head.fc.weight.data.mul_(head_init_scale)
@@ -603,13 +618,15 @@ class Hiera(nn.Module):
     def no_weight_decay(self):
         if self.pos_embed is not None:
             return ["pos_embed"]
+        elif self.pos_embed_abs is not None:
+            return ['pos_embed_abs', 'pos_embed_win']
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]
 
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict:
         return dict(
-            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
@@ -652,6 +669,18 @@ class Hiera(nn.Module):
     def _pos_embed(self, x) -> torch.Tensor:
         if self.pos_embed is not None:
             pos_embed = self.pos_embed
+        elif self.pos_embed_abs is not None:
+            # absolute win position embedding, from
+            # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
+            pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
+            pos_embed_abs = F.interpolate(
+                self.pos_embed_abs,
+                size=pos_embed_win.shape[-2:],
+                mode='bicubic',
+                antialias=True,
+            )
+            pos_embed = pos_embed_abs + pos_embed_win
+            pos_embed = pos_embed.flatten(2).transpose(1, 2)
         else:
             pos_embed = (
                 self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
@@ -735,7 +764,6 @@ class Hiera(nn.Module):
             self.head.reset(0, other=True)
         return take_indices
 
-
     def forward_features(
             self,
             x: torch.Tensor,
@@ -746,6 +774,11 @@ class Hiera(nn.Module):
         mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
         Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
         """
+        if self.training and self.patch_drop_rate > 0:
+            # using mask for something like 'patch dropout' via mask-units in supervised train / fine-tune
+            assert mask is None
+            mask = self.get_random_mask(x, mask_ratio=self.patch_drop_rate)
+
         if mask is not None:
             patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
         else:
@@ -801,6 +834,7 @@ def _cfg(url='', **kwargs):
         **kwargs
     }
 
+
 default_cfgs = generate_default_cfgs({
     "hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
         hf_hub_id='timm/',
@@ -863,21 +897,20 @@ default_cfgs = generate_default_cfgs({
     ),
 })
 
+
 def checkpoint_filter_fn(state_dict, model=None):
     state_dict = state_dict.get('model_state', state_dict)
     output = {}
     for k, v in state_dict.items():
-        if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
-            # # To resize pos embedding when using model at different size from pretrained weights
-            # from timm.layers import resample_abs_pos_embed
-            # v = resample_abs_pos_embed(
-            #     v,
-            #     new_size=(64, 64),
-            #     num_prefix_tokens=0,
-            #     verbose=True,
-            # )
-            #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
-            pass
+        # if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
+        #     # To resize pos embedding when using model at different size from pretrained weights
+        #     from timm.layers import resample_abs_pos_embed
+        #     v = resample_abs_pos_embed(
+        #         v,
+        #         new_size=(64, 64),
+        #         num_prefix_tokens=0,
+        #         verbose=True,
+        #     )
        if 'head.projection.' in k:
            k = k.replace('head.projection.', 'head.fc.')
        if k.startswith('encoder_norm.'):
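
For reference, outside the patch itself: a minimal standalone sketch of what the new `pos_embed_abs` + `pos_embed_win` branch in `_pos_embed` computes. The sizes below (`embed_dim`, an 8x8 mask unit, a 7x7 mask-unit grid over a 56x56 token grid) are illustrative assumptions, not values read from any model config; only the tile / interpolate / add / flatten sequence mirrors the patched code.

```python
import torch
import torch.nn.functional as F

# Illustrative (assumed) sizes: a 224x224 input with a 4-stride patch embed gives a
# 56x56 token grid; 8x8 mask units tile that grid 7x7; abs_pos_size is the patch default.
embed_dim = 96
mask_unit_size = (8, 8)
mask_spatial_shape = (7, 7)
abs_pos_size = (14, 14)

# Both tables are kept NCHW so tiling / interpolation happen before the add & flatten.
pos_embed_abs = torch.randn(1, embed_dim, *abs_pos_size)
pos_embed_win = torch.randn(1, embed_dim, *mask_unit_size)

# Tile the per-mask-unit table over the mask-unit grid: (1, C, 8, 8) -> (1, C, 56, 56).
pos_embed_win_tiled = pos_embed_win.tile(mask_spatial_shape)

# Resize the absolute table to the same grid and add the window component.
pos_embed = F.interpolate(
    pos_embed_abs,
    size=pos_embed_win_tiled.shape[-2:],
    mode='bicubic',
    antialias=True,
) + pos_embed_win_tiled

# Flatten to the (B, N, C) token layout the transformer blocks consume.
pos_embed = pos_embed.flatten(2).transpose(1, 2)
print(pos_embed.shape)  # torch.Size([1, 3136, 96])
```

Because only the absolute table is interpolated while the window table stays locked to the mask-unit size, the embedding can be resized to other input resolutions, which is the point of the referenced paper (arXiv:2311.05613); as the subject line notes, weights trained with the old single `pos_embed` do not carry over to this scheme.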