From 82cc53237e72ac69b83ee051940161df60df578e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 30 Sep 2023 16:03:01 -0700
Subject: [PATCH] Working on support for siglip (w/ attn pool) vit backbone,
 and adding registers (reg tokens)

---
 timm/models/_builder.py           |   6 +-
 timm/models/vision_transformer.py | 200 ++++++++++++++++++++++++++++--
 2 files changed, 192 insertions(+), 14 deletions(-)

diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index e1c8f419..1ea574e9 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -160,7 +160,11 @@ def load_pretrained(
         state_dict = pretrained_loc  # pretrained_loc is the actual state dict for this override
     elif load_from == 'file':
         _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
-        state_dict = load_state_dict(pretrained_loc)
+        if pretrained_cfg.get('custom_load', False):
+            model.load_pretrained(pretrained_loc)
+            return
+        else:
+            state_dict = load_state_dict(pretrained_loc)
     elif load_from == 'url':
         _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
         if pretrained_cfg.get('custom_load', False):
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 10b9296b..99af178c 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -377,6 +377,72 @@ class ParallelThingsBlock(nn.Module):
         return self._forward(x)
 
 
+class AttentionPoolLatent(nn.Module):
+    """ Attention pooling w/ latent query
+    """
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int = None,
+            embed_dim: int = None,
+            num_heads: int = 8,
+            mlp_ratio: float = 4.0,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            feat_size: Optional[int] = None,  # spatial feature size, needed for 'abs' pos_embed
+            latent_size: int = 1,
+            latent_dim: int = None,
+            pos_embed: str = '',
+            pool_type: str = 'token',
+            norm_layer: Optional[nn.Module] = None,
+            drop: float = 0.0,
+    ):
+        super().__init__()
+        embed_dim = embed_dim or in_features
+        out_features = out_features or in_features
+        assert embed_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.pool = pool_type
+        self.feat_size = feat_size
+
+        if pos_embed == 'abs':
+            spatial_len = self.feat_size
+            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+        else:
+            self.pos_embed = None
+
+        self.latent_dim = latent_dim or embed_dim
+        latent_size = latent_size or self.feat_size
+        self.latent_len = latent_size
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
+
+        self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm)
+        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+
+    def init_weights(self):
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+
+    def forward(self, x):
+        B, N, _ = x.shape
+
+        if self.pos_embed is not None:
+            # FIXME interpolate
+            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        latent_q = self.latent.expand(B, -1, -1)
+        x = self.attn(torch.cat([latent_q, x], dim=1))
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == 'token':
+            x = x[:, 0]
+        elif self.pool == 'avg':
+            x = x.mean(1)
+        return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -401,8 +467,10 @@ class VisionTransformer(nn.Module):
             init_values: Optional[float] = None,
             class_token: bool = True,
             no_embed_class: bool = False,
+            reg_tokens: int = 0,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
+            use_attn_pool: bool = False,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
@@ -432,6 +500,8 @@ class VisionTransformer(nn.Module):
             qkv_bias: Enable bias for qkv projections if True.
             init_values: Layer-scale init values (layer-scale enabled if not None).
             class_token: Use class token.
+            no_embed_class: Don't include position embeddings for class (or reg) tokens.
+            reg_tokens: Number of register tokens.
             fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
@@ -445,7 +515,7 @@ class VisionTransformer(nn.Module):
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'token')
-        assert class_token or global_pool != 'token'
+        assert class_token or use_attn_pool or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
@@ -454,7 +524,10 @@ class VisionTransformer(nn.Module):
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
-        self.no_embed_class = no_embed_class
+        self.num_prefix_tokens += reg_tokens
+        self.num_reg_tokens = reg_tokens
+        self.has_class_token = class_token
+        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
 
@@ -474,6 +547,7 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
         self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
@@ -506,6 +580,14 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
 
         # Classifier Head
+        if use_attn_pool:
+            self.attn_pool = AttentionPoolLatent(
+                self.embed_dim,
+                num_heads=num_heads,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.attn_pool = None
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -566,18 +648,26 @@ class VisionTransformer(nn.Module):
             x = x.view(B, -1, C)
         else:
             pos_embed = self.pos_embed
+
+        to_cat = []
+        if self.cls_token is not None:
+            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+        if self.reg_token is not None:
+            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
         if self.no_embed_class:
             # deit-3, updated JAX (big vision)
             # position embedding does not overlap with class token, add then concat
             x = x + pos_embed
-            if self.cls_token is not None:
-                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
         else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
-            if self.cls_token is not None:
-                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
             x = x + pos_embed
+
         return self.pos_drop(x)
 
     def _intermediate_layers(
@@ -605,7 +695,7 @@ class VisionTransformer(nn.Module):
             x: torch.Tensor,
             n: Union[int, Sequence] = 1,
             reshape: bool = False,
-            return_class_token: bool = False,
+            return_prefix_tokens: bool = False,
             norm: bool = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
         """ Intermediate layer accessor (NOTE: This is a WIP experiment).
@@ -615,7 +705,7 @@ class VisionTransformer(nn.Module):
         outputs = self._intermediate_layers(x, n)
         if norm:
             outputs = [self.norm(out) for out in outputs]
-        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
+        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
         outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
 
         if reshape:
@@ -625,8 +715,8 @@ class VisionTransformer(nn.Module):
                 for out in outputs
             ]
 
-        if return_class_token:
-            return tuple(zip(outputs, class_tokens))
+        if return_prefix_tokens:
+            return tuple(zip(outputs, prefix_tokens))
         return tuple(outputs)
 
     def forward_features(self, x):
@@ -642,8 +732,12 @@ class VisionTransformer(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        if self.global_pool:
-            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        if self.attn_pool is not None:
+            x = self.attn_pool(x)
+        elif self.global_pool == 'avg':
+            x = x[:, self.num_prefix_tokens:].mean(dim=1)
+        elif self.global_pool:
+            x = x[:, 0]  # class token
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
@@ -767,6 +861,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     elif 'params/embedding/kernel' in w:
         prefix = 'params/'
         big_vision = True
+    elif 'params/img/embedding/kernel' in w:
+        prefix = 'params/img/'
+        big_vision = True
 
     if hasattr(model.patch_embed, 'backbone'):
         # hybrid
@@ -823,13 +920,31 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if (isinstance(model.head, nn.Linear) and
+            f'{prefix}head/bias' in w and
+            model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
     # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
     # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
     #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
     #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+    if model.attn_pool is not None:
+        block_prefix = f'{prefix}MAPHead_0/'
+        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
+        model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
+        model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+        model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        model.attn_pool.attn.qkv.weight.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+        model.attn_pool.attn.qkv.bias.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+        model.attn_pool.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+        model.attn_pool.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+        for r in range(2):
+            getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
+            getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
+
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
@@ -1493,6 +1608,12 @@ default_cfgs = generate_default_cfgs({
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    'vit_base_patch16_siglip_224': _cfg(
+        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
+        custom_load=True,
+        # hf_hub_id='timm/',
+        num_classes=0),
 })
 
 
@@ -2119,6 +2240,59 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransfor
     return model
 
 
+@register_model
+def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, no_embed_class=True,
+        reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_medium_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=8,
+        class_token=True, no_embed_class=True, reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_reg8_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=8, global_pool='avg', reg_tokens=8,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_reg8_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
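
Reviewer note (not part of the patch): a minimal smoke-test sketch for the new entrypoints registered above. It assumes this branch of timm is installed and uses only the configs defined in the diff (224x224 inputs, patch 16); shape comments follow from those defaults.

    import torch
    import timm

    # Register-token ViT: 1 class token + 8 register tokens are prepended, so
    # num_prefix_tokens is 9 and the 14*14=196 patch tokens follow after index 9.
    model = timm.create_model('vit_base_patch16_reg8_224', pretrained=False)
    feats = model.forward_features(torch.randn(1, 3, 224, 224))
    print(feats.shape, model.num_prefix_tokens)   # torch.Size([1, 205, 768]) 9

    # SigLIP-style backbone: no class token; forward_head() routes through the new
    # AttentionPoolLatent, so the pooled output is already a (B, embed_dim) vector.
    siglip = timm.create_model('vit_base_patch16_siglip_224', pretrained=False, num_classes=0)
    pooled = siglip.forward_head(siglip.forward_features(torch.randn(1, 3, 224, 224)))
    print(pooled.shape)                           # torch.Size([1, 768])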