mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Implement absolute+window pos embed for hiera, resizable but needs new weights
This commit is contained in:
parent 392b78aee7
commit 3a8a965891
@@ -476,8 +476,11 @@ class Hiera(nn.Module):
            drop_path_rate: float = 0.0,
            norm_layer: Union[str, nn.Module] = "LayerNorm",
            drop_rate: float = 0.0,
            patch_drop_rate: float = 0.0,
            head_init_scale: float = 0.001,
            sep_pos_embed: bool = False,
            abs_win_pos_embed: bool = False,
            abs_pos_size: Tuple[int, int] = (14, 14),
    ):
        super().__init__()
        self.num_classes = num_classes
@@ -494,6 +497,7 @@ class Hiera(nn.Module):
        self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
        self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        self.patch_drop_rate = patch_drop_rate

        self.patch_embed = PatchEmbed(
            in_chans,
@@ -504,8 +508,12 @@ class Hiera(nn.Module):
            #reshape=False,  # leave spatial / temporal dims in output
        )

        self.pos_embed: Optional[nn.Parameter] = None
        self.pos_embed_abs: Optional[nn.Parameter] = None
        self.pos_embed_win: Optional[nn.Parameter] = None
        self.pos_embed_spatial: Optional[nn.Parameter] = None
        self.pos_embed_temporal: Optional[nn.Parameter] = None
        if sep_pos_embed:
            self.pos_embed = None
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim)
            )
@@ -513,9 +521,12 @@ class Hiera(nn.Module):
                torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
            )
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
            self.pos_embed_spatial = None
            self.pos_embed_temporal = None
            if abs_win_pos_embed:
                # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
                self.pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))
                self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))
            else:
                self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))

        # Setup roll and reroll modules
        self.unroll = Unroll(
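For orientation, a minimal sketch (not part of the diff) of the shapes the two new parameters take; embed_dim and mask_unit_size here are illustrative assumptions chosen to match typical Hiera defaults, while abs_pos_size matches the new argument's default. Keeping the tensors in NCHW layout is what makes the later tile and bicubic interpolate over H, W straightforward.

import torch
import torch.nn as nn

embed_dim = 96                   # assumed embed dim (e.g. a tiny variant)
abs_pos_size = (14, 14)          # default of the new abs_pos_size argument
mask_unit_size = (8, 8)          # assumed 2D mask unit of 8x8 tokens

# params kept NCHW so spatial tiling / resizing operates on the last two dims
pos_embed_abs = nn.Parameter(torch.zeros(1, embed_dim, *abs_pos_size))    # (1, C, 14, 14)
pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size))  # (1, C, 8, 8)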
@@ -584,7 +595,11 @@ class Hiera(nn.Module):
            nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
            nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
        else:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
            if self.pos_embed is not None:
                nn.init.trunc_normal_(self.pos_embed, std=0.02)
            elif self.pos_embed_abs is not None:
                nn.init.trunc_normal_(self.pos_embed_abs, std=0.02)
                nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
        self.apply(partial(self._init_weights))
        if isinstance(self.head.fc, nn.Linear):
            self.head.fc.weight.data.mul_(head_init_scale)
@@ -603,13 +618,15 @@ class Hiera(nn.Module):
    def no_weight_decay(self):
        if self.pos_embed is not None:
            return ["pos_embed"]
        elif self.pos_embed_abs is not None:
            return ['pos_embed_abs', 'pos_embed_win']
        else:
            return ["pos_embed_spatial", "pos_embed_temporal"]

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        return dict(
            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )
@@ -652,6 +669,18 @@ class Hiera(nn.Module):
    def _pos_embed(self, x) -> torch.Tensor:
        if self.pos_embed is not None:
            pos_embed = self.pos_embed
        elif self.pos_embed_abs is not None:
            # absolute win position embedding, from
            # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
            pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
            pos_embed_abs = F.interpolate(
                self.pos_embed_abs,
                size=pos_embed_win.shape[-2:],
                mode='bicubic',
                antialias=True,
            )
            pos_embed = pos_embed_abs + pos_embed_win
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
        else:
            pos_embed = (
                self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
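The elif branch above is the core of the change. The following is a self-contained sketch of the same tile + interpolate + add composition, with assumed sizes (a 56x56 token grid split into 7x7 mask units of 8x8 tokens, embed dim 96), rather than the module's registered parameters:

import torch
import torch.nn.functional as F

C = 96                                      # assumed embed dim
mask_spatial_shape = (7, 7)                 # assumed: 56x56 tokens / 8x8 mask units
pos_embed_abs = torch.randn(1, C, 14, 14)   # absolute embed learned at a fixed 14x14 grid
pos_embed_win = torch.randn(1, C, 8, 8)     # embed learned for a single mask unit

# tile the window embed across every mask unit -> full token grid (1, C, 56, 56)
win = pos_embed_win.tile(mask_spatial_shape)
# resize the absolute embed to the same grid, then add the two
abs_resized = F.interpolate(pos_embed_abs, size=win.shape[-2:], mode='bicubic', antialias=True)
pos_embed = (abs_resized + win).flatten(2).transpose(1, 2)  # (1, 56*56, C), ready to add to tokens
print(pos_embed.shape)  # torch.Size([1, 3136, 96])

Because only the absolute component is interpolated while the window component stays at the mask-unit resolution, the embedding can be resized to other input sizes without distorting the per-window pattern, which is the motivation of the referenced paper and of the "resizable" note in the commit message.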
@@ -735,7 +764,6 @@ class Hiera(nn.Module):
        self.head.reset(0, other=True)
        return take_indices


    def forward_features(
            self,
            x: torch.Tensor,
@@ -746,6 +774,11 @@ class Hiera(nn.Module):
        mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
        Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
        """
        if self.training and self.patch_drop_rate > 0:
            # using mask for something like 'patch dropout' via mask-units in supervised train / fine-tune
            assert mask is None
            mask = self.get_random_mask(x, mask_ratio=self.patch_drop_rate)

        if mask is not None:
            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
        else:
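The patch_drop_rate path above reuses the masking machinery as a form of patch dropout during supervised training. get_random_mask is the model's own helper; purely as an illustration of the idea, an MAE-style random keep mask over mask units could be built as in the sketch below (the function name and logic are assumptions, not a copy of the helper):

import torch

def random_keep_mask(batch_size: int, num_windows: int, mask_ratio: float) -> torch.Tensor:
    # Random masking over mask units: True = keep, False = remove,
    # with the same number of kept units for every sample in the batch.
    num_keep = int(num_windows * (1 - mask_ratio))
    noise = torch.rand(batch_size, num_windows)        # random score per mask unit
    ids_shuffle = torch.argsort(noise, dim=1)          # random permutation per sample
    mask = torch.zeros(batch_size, num_windows, dtype=torch.bool)
    mask.scatter_(1, ids_shuffle[:, :num_keep], True)  # keep the first num_keep units
    return mask

m = random_keep_mask(2, 49, mask_ratio=0.5)  # e.g. 7x7 = 49 mask units
print(m.shape, m.sum(dim=-1))                # equal keep count across the batch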
@@ -801,6 +834,7 @@ def _cfg(url='', **kwargs):
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    "hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
        hf_hub_id='timm/',
@@ -863,21 +897,20 @@ default_cfgs = generate_default_cfgs({
    ),
})


def checkpoint_filter_fn(state_dict, model=None):
    state_dict = state_dict.get('model_state', state_dict)
    output = {}
    for k, v in state_dict.items():
        if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # # To resize pos embedding when using model at different size from pretrained weights
            # from timm.layers import resample_abs_pos_embed
            # v = resample_abs_pos_embed(
            #     v,
            #     new_size=(64, 64),
            #     num_prefix_tokens=0,
            #     verbose=True,
            # )
            #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
            pass
        # if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
        #     # To resize pos embedding when using model at different size from pretrained weights
        #     from timm.layers import resample_abs_pos_embed
        #     v = resample_abs_pos_embed(
        #         v,
        #         new_size=(64, 64),
        #         num_prefix_tokens=0,
        #         verbose=True,
        #     )
        if 'head.projection.' in k:
            k = k.replace('head.projection.', 'head.fc.')
        if k.startswith('encoder_norm.'):
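The commented-out block above points at timm.layers.resample_abs_pos_embed as the way to resize a flat pos_embed when loading weights at a different resolution. A hedged usage sketch with assumed sizes (a pos_embed trained on a 56x56 token grid, resized to a 64x64 grid):

import torch
from timm.layers import resample_abs_pos_embed

pos_embed = torch.randn(1, 56 * 56, 96)   # assumed: flat pos_embed from a 56x56 grid
pos_embed_resized = resample_abs_pos_embed(
    pos_embed,
    new_size=(64, 64),
    num_prefix_tokens=0,   # no class/prefix token in this pos_embed
    verbose=True,
)
print(pos_embed_resized.shape)  # torch.Size([1, 4096, 96])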