Mirror of https://github.com/huggingface/pytorch-image-models.git
Fix pit implementation to be closer to deit/levit re: distillation head handling
commit 5f47518f27
parent 0862e6ebae
@@ -72,7 +72,7 @@ class VisionTransformerDistilled(VisionTransformer):
         self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
         self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
-        self.distilled_training = False
+        self.distilled_training = False  # must set this True to train w/ distillation token

         self.init_weights(weight_init)

@@ -539,7 +539,7 @@ class LevitDistilled(Levit):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
-        self.distilled_training = False
+        self.distilled_training = False  # must set this True to train w/ distillation token

     @torch.jit.ignore
     def get_classifier(self):
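Context beyond the diff: with the flag documented above, a training script has to opt in explicitly before the distilled DeiT/LeViT classes return separate class and distillation predictions. A minimal usage sketch, assuming the model is built through timm.create_model and that the distilled DeiT model name used here exists in the registry:

import timm
import torch

# Hypothetical sketch; the model name is an assumption, not taken from this commit.
model = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=False, num_classes=10)

model.distilled_training = True   # flag shown in the hunks above; defaults to False
model.train()

x = torch.randn(2, 3, 224, 224)
cls_logits, dist_logits = model(x)   # separate predictions only in distilled training mode

model.eval()
with torch.no_grad():
    logits = model(x)                # single tensor: the two head outputs are averaged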
@@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
     A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
         - https://arxiv.org/abs/2103.16302
     """
-    def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
-                 mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
-                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
+    def __init__(
+            self, img_size, patch_size, stride, base_dims, depth, heads,
+            mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
+            distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
         super(PoolingVisionTransformer, self).__init__()
         assert global_pool in ('token',)

@@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
         self.head_dist = None
         if distilled:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+        self.distilled_training = False  # must set this True to train w/ distillation token

         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
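Not part of the commit, but a quick construction-time check of how the distilled flag and the extra head relate in PiT, assuming the pit_ti_distilled_224 entry point exists in the timm registry:

import timm
import torch.nn as nn

# Hypothetical check; the model name is an assumption for illustration.
model = timm.create_model('pit_ti_distilled_224', pretrained=False)

# distilled=True gives the model a second linear head, while the new
# distilled_training flag still defaults to False until training code enables it.
assert isinstance(model.head_dist, nn.Linear)
assert model.distilled_training is False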
@@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}

+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         assert not enable, 'gradient checkpointing not supported'
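The new helper is the intended way to flip the flag: because it is decorated with @torch.jit.ignore, it is called from eager Python and never becomes part of a scripted graph. A short sketch, reusing the assumed pit_ti_distilled_224 model name:

import timm

# Hypothetical sketch; the model name is an assumption for illustration.
model = timm.create_model('pit_ti_distilled_224', pretrained=False)

model.set_distilled_training(True)    # enable separate class/distill outputs during training
assert model.distilled_training
model.set_distilled_training(False)   # back to averaged single-tensor output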
@@ -231,16 +237,30 @@ class PoolingVisionTransformer(nn.Module):
         cls_tokens = self.norm(cls_tokens)
         return cls_tokens

-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
         if self.head_dist is not None:
-            x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])  # x must be a tuple
-            if self.training and not torch.jit.is_scripting():
+            assert self.global_pool == 'token'
+            x, x_dist = x[:, 0], x[:, 1]
+            if not pre_logits:
+                x = self.head(x)
+                x_dist = self.head_dist(x_dist)
+            if self.distilled_training and self.training and not torch.jit.is_scripting():
+                # only return separate classification predictions when training in distilled mode
                 return x, x_dist
             else:
+                # during standard train / finetune, inference average the classifier predictions
                 return (x + x_dist) / 2
         else:
-            return self.head(x[:, 0])
+            if self.global_pool == 'token':
+                x = x[:, 0]
+            if not pre_logits:
+                x = self.head(x)
+            return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x


 def checkpoint_filter_fn(state_dict, model):
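The reworked forward_head is what makes a DeiT-style distillation loss easy to wire up on top of PiT. A minimal hard-distillation training-step sketch, not taken from the repo; the teacher model, the loss weighting alpha, and the optimizer handling are assumptions for illustration:

import torch
import torch.nn.functional as F

def distillation_step(student, teacher, images, targets, optimizer, alpha=0.5):
    # Hypothetical training step; assumes student.set_distilled_training(True) was called,
    # so forward() returns (class_logits, distill_logits) while in train mode.
    student.train()
    teacher.eval()

    cls_logits, dist_logits = student(images)

    with torch.no_grad():
        teacher_labels = teacher(images).argmax(dim=-1)   # hard distillation targets

    loss_cls = F.cross_entropy(cls_logits, targets)
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)
    loss = (1 - alpha) * loss_cls + alpha * loss_dist

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()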