From 5f47518f2744473d18bad27191e5c5dba77fd1dc Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 21 Mar 2022 11:12:14 -0700
Subject: [PATCH] Fix pit implementation to be closer to deit/levit re distillation head handling

---
 timm/models/deit.py  |  2 +-
 timm/models/levit.py |  2 +-
 timm/models/pit.py   | 36 ++++++++++++++++++++++++++++--------
 3 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/timm/models/deit.py b/timm/models/deit.py
index 1251c373..e6b4b025 100644
--- a/timm/models/deit.py
+++ b/timm/models/deit.py
@@ -72,7 +72,7 @@ class VisionTransformerDistilled(VisionTransformer):
         self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
         self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
-        self.distilled_training = False
+        self.distilled_training = False  # must set this True to train w/ distillation token
 
         self.init_weights(weight_init)
 
diff --git a/timm/models/levit.py b/timm/models/levit.py
index e93662ae..cea9f0fc 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -539,7 +539,7 @@ class LevitDistilled(Levit):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
-        self.distilled_training = False
+        self.distilled_training = False  # must set this True to train w/ distillation token
 
     @torch.jit.ignore
     def get_classifier(self):
diff --git a/timm/models/pit.py b/timm/models/pit.py
index bef625ab..0f571319 100644
--- a/timm/models/pit.py
+++ b/timm/models/pit.py
@@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
     A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
         - https://arxiv.org/abs/2103.16302
     """
-    def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
-                 mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
-                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
+    def __init__(
+            self, img_size, patch_size, stride, base_dims, depth, heads,
+            mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
+            distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
         super(PoolingVisionTransformer, self).__init__()
         assert global_pool in ('token',)
 
@@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
         self.head_dist = None
         if distilled:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+        self.distilled_training = False  # must set this True to train w/ distillation token
 
         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
@@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
     def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}
 
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         assert not enable, 'gradient checkpointing not supported'
@@ -231,16 +237,30 @@
         cls_tokens = self.norm(cls_tokens)
         return cls_tokens
 
-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         if self.head_dist is not None:
-            x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])  # x must be a tuple
-            if self.training and not torch.jit.is_scripting():
+            assert self.global_pool == 'token'
+            x, x_dist = x[:, 0], x[:, 1]
+            if not pre_logits:
+                x = self.head(x)
+                x_dist = self.head_dist(x_dist)
+            if self.distilled_training and self.training and not torch.jit.is_scripting():
+                # only return separate classification predictions when training in distilled mode
                 return x, x_dist
             else:
+                # during standard train / finetune and at inference, average the two classifier predictions
                 return (x + x_dist) / 2
         else:
-            return self.head(x[:, 0])
+            if self.global_pool == 'token':
+                x = x[:, 0]
+            if not pre_logits:
+                x = self.head(x)
+            return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
 
 
 def checkpoint_filter_fn(state_dict, model):
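
Usage note (editorial, not part of the patch): the sketch below shows how the set_distilled_training() hook added here for PiT is intended to be used, mirroring the existing DeiT/LeViT behaviour. With distillation enabled and the model in train mode, forward() returns separate class-token and distillation-token logits, which the training loop combines with a teacher's predictions. The RegNetY-160 teacher, loss weighting, and optimizer settings are illustrative assumptions, not part of this change.

    # Minimal hard-distillation training step for a distilled PiT model (sketch).
    # Assumes this patch is applied; teacher, optimizer and loss weights are placeholders.
    import torch
    import torch.nn.functional as F
    import timm

    student = timm.create_model('pit_b_distilled_224', pretrained=False)
    student.set_distilled_training(True)  # added by this patch; train-mode forward now returns (cls, dist) logits
    student.train()

    teacher = timm.create_model('regnety_160', pretrained=True)
    teacher.eval()

    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)

    def train_step(images, targets):
        with torch.no_grad():
            teacher_labels = teacher(images).argmax(dim=1)  # hard teacher labels, DeiT-style
        cls_logits, dist_logits = student(images)  # tuple only because distilled_training and training are both set
        loss = 0.5 * F.cross_entropy(cls_logits, targets) + 0.5 * F.cross_entropy(dist_logits, teacher_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()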