From 82cc53237e72ac69b83ee051940161df60df578e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 30 Sep 2023 16:03:01 -0700 Subject: [PATCH 1/8] Working on support for siglip (w/ attn pool) vit backbone, and adding registers (reg tokens) --- timm/models/_builder.py | 6 +- timm/models/vision_transformer.py | 200 ++++++++++++++++++++++++++++-- 2 files changed, 192 insertions(+), 14 deletions(-) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index e1c8f419..1ea574e9 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -160,7 +160,11 @@ def load_pretrained( state_dict = pretrained_loc # pretrained_loc is the actual state dict for this override elif load_from == 'file': _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') - state_dict = load_state_dict(pretrained_loc) + if pretrained_cfg.get('custom_load', False): + model.load_pretrained(pretrained_loc) + return + else: + state_dict = load_state_dict(pretrained_loc) elif load_from == 'url': _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') if pretrained_cfg.get('custom_load', False): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 10b9296b..99af178c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -377,6 +377,72 @@ class ParallelThingsBlock(nn.Module): return self._forward(x) +class AttentionPoolLatent(nn.Module): + """ Attention pooling w/ latent query + """ + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + latent_size: int = 1, + latent_dim: int = None, + pos_embed: str = '', + pool_type: str = 'token', + norm_layer: Optional[nn.Module] = None, + drop: float = 0.0, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.pool = pool_type + + if pos_embed == 'abs': + spatial_len = self.feat_size + self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features)) + else: + self.pos_embed = None + + self.latent_dim = latent_dim or embed_dim + latent_size = latent_size or self.feat_size + self.latent_len = latent_size + self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) + + self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm) + self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() + self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) + + def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + + def forward(self, x): + B, N, _ = x.shape + + if self.pos_embed is not None: + # FIXME interpolate + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + latent_q = self.latent.expand(B, -1, -1) + x = self.attn(torch.cat([latent_q, x], dim=1)) + x = x + self.mlp(self.norm(x)) + + # optional pool if latent seq_len > 1 and pooled output is desired + if self.pool == 'token': + x = x[:, 0] + elif self.pool == 'avg': + x = x.mean(1) + return x + + class VisionTransformer(nn.Module): """ Vision Transformer @@ -401,8 +467,10 @@ class VisionTransformer(nn.Module): init_values: Optional[float] = None, class_token: bool = True, no_embed_class: bool = False, + reg_tokens: int = 0, pre_norm: bool = False, fc_norm: Optional[bool] = None, + use_attn_pool: bool = False, dynamic_img_size: bool = False, dynamic_img_pad: bool = False, drop_rate: float = 0., @@ -432,6 +500,8 @@ class VisionTransformer(nn.Module): qkv_bias: Enable bias for qkv projections if True. init_values: Layer-scale init values (layer-scale enabled if not None). class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. drop_rate: Head dropout rate. pos_drop_rate: Position embedding dropout rate. @@ -445,7 +515,7 @@ class VisionTransformer(nn.Module): """ super().__init__() assert global_pool in ('', 'avg', 'token') - assert class_token or global_pool != 'token' + assert class_token or use_attn_pool or global_pool != 'token' use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU @@ -454,7 +524,10 @@ class VisionTransformer(nn.Module): self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 - self.no_embed_class = no_embed_class + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False @@ -474,6 +547,7 @@ class VisionTransformer(nn.Module): num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=pos_drop_rate) @@ -506,6 +580,14 @@ class VisionTransformer(nn.Module): self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head + if use_attn_pool == 'map': + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() @@ -566,18 +648,26 @@ class VisionTransformer(nn.Module): x = x.view(B, -1, C) else: pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat x = x + pos_embed - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if to_cat: + x = torch.cat(to_cat + [x], dim=1) else: # original timm, JAX, and deit vit impl # pos_embed has entry for class token, concat then add - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if to_cat: + x = torch.cat(to_cat + [x], dim=1) x = x + pos_embed + return self.pos_drop(x) def _intermediate_layers( @@ -605,7 +695,7 @@ class VisionTransformer(nn.Module): x: torch.Tensor, n: Union[int, Sequence] = 1, reshape: bool = False, - return_class_token: bool = False, + return_prefix_tokens: bool = False, norm: bool = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: """ Intermediate layer accessor (NOTE: This is a WIP experiment). @@ -615,7 +705,7 @@ class VisionTransformer(nn.Module): outputs = self._intermediate_layers(x, n) if norm: outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] + prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] outputs = [out[:, self.num_prefix_tokens:] for out in outputs] if reshape: @@ -625,8 +715,8 @@ class VisionTransformer(nn.Module): for out in outputs ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) return tuple(outputs) def forward_features(self, x): @@ -642,8 +732,12 @@ class VisionTransformer(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): - if self.global_pool: - x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == 'avg': + x = x[:, self.num_prefix_tokens:].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) @@ -767,6 +861,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = elif 'params/embedding/kernel' in w: prefix = 'params/' big_vision = True + elif 'params/img/embedding/kernel' in w: + prefix = 'params/img/' + big_vision = True if hasattr(model.patch_embed, 'backbone'): # hybrid @@ -823,13 +920,31 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) - if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + if (isinstance(model.head, nn.Linear) and + f'{prefix}head/bias' in w and + model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]): model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + if model.attn_pool is not None: + block_prefix = f'{prefix}MAPHead_0/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) + model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + model.attn_pool.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + model.attn_pool.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + model.attn_pool.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + model.attn_pool.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) + getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) + mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' @@ -1493,6 +1608,12 @@ default_cfgs = generate_default_cfgs({ # hf_hub_id='timm/', license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + 'vit_base_patch16_siglip_224': _cfg( + file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz', + custom_load=True, + # hf_hub_id='timm/', + num_classes=0), }) @@ -2119,6 +2240,59 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransfor return model +@register_model +def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True, + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, no_embed_class=True, + reg_tokens=8, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=8, + class_token=True, no_embed_class=True, reg_tokens=8, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, + class_token=True, no_embed_class=True, reg_tokens=8, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_reg8_gap_224(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=8, global_pool='avg', reg_tokens=8, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg8_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + register_model_deprecations(__name__, { 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k', From 99cfd6702ffd4086749b461522ab83d6a3559213 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 30 Sep 2023 16:16:21 -0700 Subject: [PATCH 2/8] Use global pool arg to select attention pooling in head --- timm/models/vision_transformer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 99af178c..1225e10a 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -470,7 +470,6 @@ class VisionTransformer(nn.Module): reg_tokens: int = 0, pre_norm: bool = False, fc_norm: Optional[bool] = None, - use_attn_pool: bool = False, dynamic_img_size: bool = False, dynamic_img_pad: bool = False, drop_rate: float = 0., @@ -514,8 +513,8 @@ class VisionTransformer(nn.Module): block_fn: Transformer block layer. """ super().__init__() - assert global_pool in ('', 'avg', 'token') - assert class_token or use_attn_pool or global_pool != 'token' + assert global_pool in ('', 'avg', 'token', 'map') + assert class_token or global_pool != 'token' use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU @@ -580,7 +579,7 @@ class VisionTransformer(nn.Module): self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head - if use_attn_pool == 'map': + if global_pool == 'pool': self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=num_heads, @@ -2243,7 +2242,7 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransfor @register_model def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True, + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', ) model = _create_vision_transformer( 'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) From b9dde580763425afc188adb585c6395531a04cc8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 2 Oct 2023 11:44:12 -0700 Subject: [PATCH 3/8] Fixup attention pooling in siglip vit support --- timm/models/vision_transformer.py | 74 ++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 1225e10a..107628db 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -389,7 +389,7 @@ class AttentionPoolLatent(nn.Module): mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, - latent_size: int = 1, + latent_len: int = 1, latent_dim: int = None, pos_embed: str = '', pool_type: str = 'token', @@ -404,6 +404,7 @@ class AttentionPoolLatent(nn.Module): self.head_dim = embed_dim // num_heads self.scale = self.head_dim ** -0.5 self.pool = pool_type + self.fused_attn = use_fused_attn() if pos_embed == 'abs': spatial_len = self.feat_size @@ -412,11 +413,16 @@ class AttentionPoolLatent(nn.Module): self.pos_embed = None self.latent_dim = latent_dim or embed_dim - latent_size = latent_size or self.feat_size - self.latent_len = latent_size + self.latent_len = latent_len self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) - self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm) + self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.proj = nn.Linear(embed_dim, embed_dim) + self.proj_drop = nn.Dropout(drop) + self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) @@ -425,14 +431,31 @@ class AttentionPoolLatent(nn.Module): trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) def forward(self, x): - B, N, _ = x.shape + B, N, C = x.shape if self.pos_embed is not None: # FIXME interpolate x = x + self.pos_embed.unsqueeze(0).to(x.dtype) - latent_q = self.latent.expand(B, -1, -1) - x = self.attn(torch.cat([latent_q, x], dim=1)) + q_latent = self.latent.expand(B, -1, -1) + q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + x = attn @ v + x = x.transpose(1, 2).reshape(B, self.latent_len, C) + x = self.proj(x) + x = self.proj_drop(x) + x = x + self.mlp(self.norm(x)) # optional pool if latent seq_len > 1 and pooled output is desired @@ -579,7 +602,7 @@ class VisionTransformer(nn.Module): self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head - if global_pool == 'pool': + if global_pool == 'map': self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=num_heads, @@ -932,14 +955,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = block_prefix = f'{prefix}MAPHead_0/' mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) + model.attn_pool.kv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) + model.attn_pool.kv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) + model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) + model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) + model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - model.attn_pool.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - model.attn_pool.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - model.attn_pool.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - model.attn_pool.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) @@ -956,11 +981,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) for r in range(2): getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) def _convert_openai_clip(state_dict, model): @@ -1613,6 +1638,12 @@ default_cfgs = generate_default_cfgs({ custom_load=True, # hf_hub_id='timm/', num_classes=0), + 'vit_base_patch16_siglip_256': _cfg( + file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz', + custom_load=True, + input_size=(3, 256, 256), + # hf_hub_id='timm/', + num_classes=0), }) @@ -2249,6 +2280,17 @@ def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer return model +@register_model +def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + + @register_model def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( From 42daa3b497007aa3b25abbf54000f26c9300e95d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 10 Oct 2023 22:15:45 -0700 Subject: [PATCH 4/8] Add full set of SigLIP models --- timm/models/vision_transformer.py | 93 +++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 107628db..ff753679 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -606,6 +606,7 @@ class VisionTransformer(nn.Module): self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, norm_layer=norm_layer, ) else: @@ -1644,6 +1645,39 @@ default_cfgs = generate_default_cfgs({ input_size=(3, 256, 256), # hf_hub_id='timm/', num_classes=0), + 'vit_base_patch16_siglip_384': _cfg( + file='', + custom_load=True, + input_size=(3, 384, 384), + # hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_siglip_512': _cfg( + file='', + custom_load=True, + input_size=(3, 512, 512), + # hf_hub_id='timm/', + num_classes=0), + 'vit_large_patch16_siglip_256': _cfg( + custom_load=True, + input_size=(3, 256, 256), + # hf_hub_id='timm/', + num_classes=0), + 'vit_large_patch16_siglip_384': _cfg( + custom_load=True, + input_size=(3, 384, 384), + # hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_224': _cfg( + # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz', + custom_load=True, + # hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_384': _cfg( + #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz', + custom_load=True, + # hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), }) @@ -2290,6 +2324,65 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer return model +@register_model +def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + @register_model def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: From 71365165a2e3d87bc65dca7d2f251a2beea08d7b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 16 Oct 2023 23:26:08 -0700 Subject: [PATCH 5/8] Add SigLIP weights --- timm/layers/attention_pool.py | 103 ++++++++++++++++++++ timm/models/_hub.py | 4 +- timm/models/vision_transformer.py | 153 +++++++----------------------- 3 files changed, 137 insertions(+), 123 deletions(-) create mode 100644 timm/layers/attention_pool.py diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py new file mode 100644 index 00000000..41e404d2 --- /dev/null +++ b/timm/layers/attention_pool.py @@ -0,0 +1,103 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import use_fused_attn +from .mlp import Mlp +from .weight_init import trunc_normal_tf_ + + +class AttentionPoolLatent(nn.Module): + """ Attention pooling w/ latent query + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + latent_len: int = 1, + latent_dim: int = None, + pos_embed: str = '', + pool_type: str = 'token', + norm_layer: Optional[nn.Module] = None, + drop: float = 0.0, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.pool = pool_type + self.fused_attn = use_fused_attn() + + if pos_embed == 'abs': + spatial_len = self.feat_size + self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features)) + else: + self.pos_embed = None + + self.latent_dim = latent_dim or embed_dim + self.latent_len = latent_len + self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) + + self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.proj = nn.Linear(embed_dim, embed_dim) + self.proj_drop = nn.Dropout(drop) + + self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() + self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) + + self.init_weights() + + def init_weights(self): + if self.pos_embed is not None: + trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) + + def forward(self, x): + B, N, C = x.shape + + if self.pos_embed is not None: + # FIXME interpolate + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + q_latent = self.latent.expand(B, -1, -1) + q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + x = attn @ v + x = x.transpose(1, 2).reshape(B, self.latent_len, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x + self.mlp(self.norm(x)) + + # optional pool if latent seq_len > 1 and pooled output is desired + if self.pool == 'token': + x = x[:, 0] + elif self.pool == 'avg': + x = x.mean(1) + return x \ No newline at end of file diff --git a/timm/models/_hub.py b/timm/models/_hub.py index e2152d21..720a5091 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]: """ if filename == HF_WEIGHTS_NAME: yield HF_SAFE_WEIGHTS_NAME - # if filename == HF_OPEN_CLIP_WEIGHTS_NAME: # FIXME tracking safetensors yet - # yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME + if filename == HF_OPEN_CLIP_WEIGHTS_NAME: + yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"): yield filename[:-4] + ".safetensors" diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ff753679..bd9158aa 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -37,8 +37,8 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ - resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked +from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ + trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -377,95 +377,6 @@ class ParallelThingsBlock(nn.Module): return self._forward(x) -class AttentionPoolLatent(nn.Module): - """ Attention pooling w/ latent query - """ - def __init__( - self, - in_features: int, - out_features: int = None, - embed_dim: int = None, - num_heads: int = 8, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - qk_norm: bool = False, - latent_len: int = 1, - latent_dim: int = None, - pos_embed: str = '', - pool_type: str = 'token', - norm_layer: Optional[nn.Module] = None, - drop: float = 0.0, - ): - super().__init__() - embed_dim = embed_dim or in_features - out_features = out_features or in_features - assert embed_dim % num_heads == 0 - self.num_heads = num_heads - self.head_dim = embed_dim // num_heads - self.scale = self.head_dim ** -0.5 - self.pool = pool_type - self.fused_attn = use_fused_attn() - - if pos_embed == 'abs': - spatial_len = self.feat_size - self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features)) - else: - self.pos_embed = None - - self.latent_dim = latent_dim or embed_dim - self.latent_len = latent_len - self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) - - self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) - self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.proj = nn.Linear(embed_dim, embed_dim) - self.proj_drop = nn.Dropout(drop) - - self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() - self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) - - def init_weights(self): - if self.pos_embed is not None: - trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) - - def forward(self, x): - B, N, C = x.shape - - if self.pos_embed is not None: - # FIXME interpolate - x = x + self.pos_embed.unsqueeze(0).to(x.dtype) - - q_latent = self.latent.expand(B, -1, -1) - q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - k, v = kv.unbind(0) - - q, k = self.q_norm(q), self.k_norm(k) - - if self.fused_attn: - x = F.scaled_dot_product_attention(q, k, v) - else: - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - x = attn @ v - x = x.transpose(1, 2).reshape(B, self.latent_len, C) - x = self.proj(x) - x = self.proj_drop(x) - - x = x + self.mlp(self.norm(x)) - - # optional pool if latent seq_len > 1 and pooled output is desired - if self.pool == 'token': - x = x[:, 0] - elif self.pool == 'avg': - x = x.mean(1) - return x - - class VisionTransformer(nn.Module): """ Vision Transformer @@ -1072,6 +983,12 @@ def checkpoint_filter_fn( if "encoder" in state_dict: state_dict = _convert_ijepa(state_dict, model) + if 'visual.trunk.pos_embed' in state_dict: + # convert an OpenCLIP model with timm vision encoder + prefix = 'visual.trunk.' + state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k: O, I, H, W = model.patch_embed.proj.weight.shape @@ -1634,48 +1551,42 @@ default_cfgs = generate_default_cfgs({ license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_base_patch16_siglip_224': _cfg( - file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz', - custom_load=True, - # hf_hub_id='timm/', + 'vit_base_patch16_siglip_224.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP', + hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0), - 'vit_base_patch16_siglip_256': _cfg( - file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz', - custom_load=True, + 'vit_base_patch16_siglip_256.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP-256', + hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 256, 256), - # hf_hub_id='timm/', num_classes=0), - 'vit_base_patch16_siglip_384': _cfg( - file='', - custom_load=True, + 'vit_base_patch16_siglip_384.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP-384', + hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 384, 384), - # hf_hub_id='timm/', num_classes=0), - 'vit_base_patch16_siglip_512': _cfg( - file='', - custom_load=True, + 'vit_base_patch16_siglip_512.webli': _cfg( + hf_hub_id='timm/ViT-B-16-SigLIP-512', + hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 512, 512), - # hf_hub_id='timm/', num_classes=0), - 'vit_large_patch16_siglip_256': _cfg( - custom_load=True, + 'vit_large_patch16_siglip_256.webli': _cfg( + hf_hub_id='timm/ViT-L-16-SigLIP-256', + hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 256, 256), - # hf_hub_id='timm/', num_classes=0), - 'vit_large_patch16_siglip_384': _cfg( - custom_load=True, + 'vit_large_patch16_siglip_384.webli': _cfg( + hf_hub_id='timm/ViT-L-16-SigLIP-384', + hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 384, 384), - # hf_hub_id='timm/', num_classes=0), - 'vit_so400m_patch14_siglip_224': _cfg( - # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz', - custom_load=True, - # hf_hub_id='timm/', + 'vit_so400m_patch14_siglip_224.webli': _cfg( + hf_hub_id='timm/ViT-SO400M-14-SigLIP', + hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0), - 'vit_so400m_patch14_siglip_384': _cfg( - #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz', - custom_load=True, - # hf_hub_id='timm/', + 'vit_so400m_patch14_siglip_384.webli': _cfg( + hf_hub_id='timm/ViT-SO400M-14-SigLIP-384', + hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 384, 384), num_classes=0), }) From 59b622233bc57bbc62756469d2309e6c393b1aec Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Oct 2023 07:16:17 -0700 Subject: [PATCH 6/8] Change ijepa names, add pretrain cfg for reg experimentts --- timm/models/vision_transformer.py | 58 ++++++++++++++----------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bd9158aa..c992c14f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1529,23 +1529,23 @@ default_cfgs = generate_default_cfgs({ license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_huge_patch14_224_ijepa.in1k': _cfg( + 'vit_huge_patch14_ijepa_224.in1k': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_huge_patch14_224_ijepa.in22k': _cfg( + 'vit_huge_patch14_ijepa_224.in22k': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_huge_patch16_448_ijepa.in1k': _cfg( + 'vit_huge_patch16_ijepa_448.in1k': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', input_size=(3, 448, 448), crop_pct=1.0, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_gigantic_patch16_224_ijepa.in22k': _cfg( + 'vit_gigantic_patch16_ijepa_224.in22k': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', @@ -1589,6 +1589,12 @@ default_cfgs = generate_default_cfgs({ hf_hub_filename='open_clip_pytorch_model.bin', input_size=(3, 384, 384), num_classes=0), + + 'vit_medium_patch16_reg4_256': _cfg( + input_size=(3, 256, 256)), + 'vit_medium_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)), }) @@ -2185,33 +2191,33 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243 """ model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg') model = _create_vision_transformer( - 'vit_huge_patch14_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_huge_patch16_448_ijepa(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243 """ model_args = dict( patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448) model = _create_vision_transformer( - 'vit_huge_patch16_448_ijepa', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer: +def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243 """ model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) model = _create_vision_transformer( - 'vit_gigantic_patch16_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2296,45 +2302,35 @@ def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_reg4_256(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( - patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, no_embed_class=True, - reg_tokens=8, + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, + no_embed_class=True, reg_tokens=4, ) model = _create_vision_transformer( - 'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_base_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: - model_args = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=8, - class_token=True, no_embed_class=True, reg_tokens=8, - ) - model = _create_vision_transformer( - 'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -@register_model -def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_reg4_gap_256(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, - class_token=True, no_embed_class=True, reg_tokens=8, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', ) model = _create_vision_transformer( - 'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_base_patch16_reg8_gap_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_reg8_gap_256(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=8, global_pool='avg', reg_tokens=8, + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, + no_embed_class=True, global_pool='avg', reg_tokens=8, ) model = _create_vision_transformer( - 'vit_base_patch16_reg8_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model From a58f9162d7ee6c981dc1ed5d2a18f9b74ee38a47 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Oct 2023 09:28:21 -0700 Subject: [PATCH 7/8] Missed __init__.py update for attention pooling layer add --- timm/layers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 5a610da6..2cdcfd98 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,6 +1,7 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .blur_pool import BlurPool2d from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead From e728f3efdb1e5da816d3defa6dd5f60b2090a8ac Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 17 Oct 2023 15:44:46 -0700 Subject: [PATCH 8/8] Cleanup ijepa models, they're just gap (global-avg-pool) models w/o heads. fc-norm conversion was wrong, gigantic should have been giant --- timm/models/vision_transformer.py | 97 +++++++++++++++---------------- 1 file changed, 47 insertions(+), 50 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 2d374085..d56037c6 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model): return out_dict -def _convert_ijepa(state_dict, model): - out_dict = {} - for k, v in state_dict['encoder'].items(): - if k.startswith('module.'): - k = k[7:] - if k.startswith('norm.'): - k = 'fc_norm.' + k[5:] - out_dict[k] = v - return out_dict - - def checkpoint_filter_fn( state_dict, model, @@ -973,6 +962,7 @@ def checkpoint_filter_fn( out_dict = {} state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('state_dict', state_dict) + prefix = '' if 'visual.class_embedding' in state_dict: return _convert_openai_clip(state_dict, model) @@ -981,13 +971,17 @@ def checkpoint_filter_fn( state_dict = _convert_dinov2(state_dict, model) if "encoder" in state_dict: - state_dict = _convert_ijepa(state_dict, model) + state_dict = state_dict['encoder'] + prefix = 'module.' if 'visual.trunk.pos_embed' in state_dict: # convert an OpenCLIP model with timm vision encoder - prefix = 'visual.trunk.' - state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + prefix = 'visual.trunk.' + + if prefix: + # filter on & remove prefix string from keys + state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k: @@ -1529,23 +1523,23 @@ default_cfgs = generate_default_cfgs({ license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_huge_patch14_ijepa_224.in1k': _cfg( + 'vit_huge_patch14_gap_224.in1k_ijepa': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_huge_patch14_ijepa_224.in22k': _cfg( + 'vit_huge_patch14_gap_224.in22k_ijepa': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_huge_patch16_ijepa_448.in1k': _cfg( + 'vit_huge_patch16_gap_448.in1k_ijepa': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', input_size=(3, 448, 448), crop_pct=1.0, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_gigantic_patch16_ijepa_224.in22k': _cfg( + 'vit_giant_patch16_gap_224.in22k_ijepa': _cfg( url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar', # hf_hub_id='timm/', license='cc-by-nc-4.0', @@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: - """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256 + """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224 """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) @@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: return model +@register_model +def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool + """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448 + """ + model_args = dict( + patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool + """ + model_args = dict( + patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-B/32 CLIP image tower @ 224x224 @@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: return model -@register_model -def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer: - """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243 - """ - model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg') - model = _create_vision_transformer( - 'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -@register_model -def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer: - """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243 - """ - model_args = dict( - patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448) - model = _create_vision_transformer( - 'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -@register_model -def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer: - """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243 - """ - model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) - model = _create_vision_transformer( - 'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - @register_model def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(