mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix pit implementation to be clsoer to deit/levit re distillation head handling
This commit is contained in:
parent
0862e6ebae
commit
5f47518f27
@ -72,7 +72,7 @@ class VisionTransformerDistilled(VisionTransformer):
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
||||
self.distilled_training = False
|
||||
self.distilled_training = False # must set this True to train w/ distillation token
|
||||
|
||||
self.init_weights(weight_init)
|
||||
|
||||
|
@ -539,7 +539,7 @@ class LevitDistilled(Levit):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
||||
self.distilled_training = False
|
||||
self.distilled_training = False # must set this True to train w/ distillation token
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
|
@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
|
||||
A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
|
||||
- https://arxiv.org/abs/2103.16302
|
||||
"""
|
||||
def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
|
||||
mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
|
||||
attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
|
||||
def __init__(
|
||||
self, img_size, patch_size, stride, base_dims, depth, heads,
|
||||
mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
|
||||
distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
|
||||
super(PoolingVisionTransformer, self).__init__()
|
||||
assert global_pool in ('token',)
|
||||
|
||||
@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
|
||||
self.head_dist = None
|
||||
if distilled:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.distilled_training = False # must set this True to train w/ distillation token
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_distilled_training(self, enable=True):
|
||||
self.distilled_training = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
assert not enable, 'gradient checkpointing not supported'
|
||||
@ -231,16 +237,30 @@ class PoolingVisionTransformer(nn.Module):
|
||||
cls_tokens = self.norm(cls_tokens)
|
||||
return cls_tokens
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
|
||||
if self.head_dist is not None:
|
||||
x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) # x must be a tuple
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
assert self.global_pool == 'token'
|
||||
x, x_dist = x[:, 0], x[:, 1]
|
||||
if not pre_logits:
|
||||
x = self.head(x)
|
||||
x_dist = self.head_dist(x_dist)
|
||||
if self.distilled_training and self.training and not torch.jit.is_scripting():
|
||||
# only return separate classification predictions when training in distilled mode
|
||||
return x, x_dist
|
||||
else:
|
||||
# during standard train / finetune, inference average the classifier predictions
|
||||
return (x + x_dist) / 2
|
||||
else:
|
||||
return self.head(x[:, 0])
|
||||
if self.global_pool == 'token':
|
||||
x = x[:, 0]
|
||||
if not pre_logits:
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
|
Loading…
x
Reference in New Issue
Block a user