Fix #262, num_classes arg mixup. Make vision_transformer a bit closer to other models wrt get/reset classifier/forward_features. Fix torchscript for ViT.
parent da1b90e5c9
commit f944242cb0
timm/models/vision_transformer.py

@@ -107,7 +107,8 @@ class Attention(nn.Module):

     def forward(self, x):
         B, N, C = x.shape
-        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
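Why the qkv change: TorchScript cannot unpack a tensor the way it unpacks a tuple, so q, k, v = tensor fails under torch.jit.script, while explicit indexing compiles. A minimal sketch of the constraint (hypothetical module, not from the repo):

    import torch
    import torch.nn as nn

    class Unpack(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            qkv = x.reshape(3, -1)
            # q, k, v = qkv                    # torch.jit.script rejects this:
            #                                  # a tensor cannot be used as a tuple
            q, k, v = qkv[0], qkv[1], qkv[2]   # explicit indexing scripts fine
            return q + k + v

    scripted = torch.jit.script(Unpack())
    print(scripted(torch.arange(6.)))          # tensor([6., 9.])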
@@ -204,6 +205,9 @@ class VisionTransformer(nn.Module):
                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                  drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
         super().__init__()
+        self.num_classes = num_classes
+        self.embed_dim = embed_dim
+
         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
                 hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
@@ -229,7 +233,7 @@ class VisionTransformer(nn.Module):
         #self.repr_act = nn.Tanh()

         # Classifier head
-        self.head = nn.Linear(embed_dim, num_classes)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

         trunc_normal_(self.pos_embed, std=.02)
         trunc_normal_(self.cls_token, std=.02)
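The head change means a model built with num_classes=0 keeps a callable head: nn.Identity() passes features through unchanged, so downstream code never has to branch on whether a classifier exists (nn.Linear(embed_dim, 0) would be a degenerate layer). A small standalone sketch of the pattern:

    import torch
    import torch.nn as nn

    embed_dim = 768
    feats = torch.randn(2, embed_dim)

    for num_classes in (1000, 0):
        # same conditional as in the diff above
        head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        print(head(feats).shape)   # [2, 1000], then [2, 768] (features pass through)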
@@ -244,11 +248,18 @@ class VisionTransformer(nn.Module):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

-    @property
+    @torch.jit.ignore
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}

-    def forward(self, x):
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
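get_classifier/reset_classifier bring ViT in line with the convention timm's other models follow, so generic training code can re-head a model for a new dataset without rebuilding it. A usage sketch against this class (shapes assume the defaults, embed_dim=768, img_size=224):

    import torch
    from timm.models.vision_transformer import VisionTransformer

    model = VisionTransformer(num_classes=1000)
    print(model.get_classifier())        # Linear(in_features=768, out_features=1000, ...)

    # swap in a 10-class head; patch_embed, pos_embed and blocks are untouched
    model.reset_classifier(num_classes=10)
    logits = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)                  # torch.Size([2, 10])

    # num_classes=0 removes the head entirely (nn.Identity)
    model.reset_classifier(num_classes=0)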
@@ -261,7 +272,11 @@ class VisionTransformer(nn.Module):
             x = blk(x)

         x = self.norm(x)
-        x = self.head(x[:, 0])
+        return x[:, 0]
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
         return x
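Splitting forward into forward_features plus a forward that applies the head matches the rest of timm and, together with the qkv fix and the @torch.jit.ignore on no_weight_decay above, is what makes the model scriptable. A smoke-test sketch (whether the whole file scripts cleanly depends on the remaining code, but this is the intent of the change):

    import torch
    from timm.models.vision_transformer import VisionTransformer

    model = VisionTransformer().eval()   # defaults: embed_dim=768, num_classes=1000
    x = torch.randn(1, 3, 224, 224)

    feats = model.forward_features(x)    # normalized cls-token features, [1, 768]
    logits = model(x)                    # equivalent to model.head(feats), [1, 1000]

    scripted = torch.jit.script(model)   # the point of the torchscript fixes
    assert torch.allclose(scripted(x), logits)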
@@ -284,7 +299,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
     model.default_cfg = default_cfgs['vit_small_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
     return model
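This is the #262 num_classes mixup: kwargs.get('num_classes', 0) told load_pretrained the head had 0 classes whenever the caller relied on the model default (1000), so the pretrained classifier weights were handled as if the head had been removed, leaving a randomly initialized head. model.num_classes always reflects the head that was actually built. The same one-line fix is applied to every ViT factory function below. A simplified sketch (not the repo's exact code) of why the argument matters inside load_pretrained:

    # pretrained classifier weights can only be loaded when the requested
    # head matches the pretrained one; otherwise they must be dropped
    def load_pretrained_sketch(model, state_dict, cfg_num_classes=1000, num_classes=1000):
        if num_classes != cfg_num_classes:
            state_dict.pop('head.weight', None)   # 'head' is ViT's classifier
            state_dict.pop('head.bias', None)
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict)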
@@ -297,7 +312,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
     model.default_cfg = default_cfgs['vit_base_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
     return model
@@ -308,8 +323,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
@@ -320,8 +334,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch32_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
@@ -339,8 +352,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
@@ -351,8 +363,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch32_384']
     if pretrained:
-        load_pretrained(
-            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
timm/optim/optim_factory.py

@@ -43,7 +43,7 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     if weight_decay and filter_bias_and_bn:
         skip = {}
         if hasattr(model, 'no_weight_decay'):
-            skip = model.no_weight_decay
+            skip = model.no_weight_decay()
         parameters = add_weight_decay(model, weight_decay, skip)
         weight_decay = 0.
     else:
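The optimizer factory previously read model.no_weight_decay as a property; in this commit no_weight_decay becomes a plain method (so it can carry @torch.jit.ignore and stay out of scripting), and the factory must now call it to get the skip set. That set feeds add_weight_decay, which splits parameters into a decayed and an undecayed group. A condensed sketch of that grouping (same signature as timm's helper, body simplified):

    def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # 1-D tensors (norm scales, biases) and explicitly skipped params
            # (here: ViT's pos_embed and cls_token) get no weight decay
            if len(param.shape) == 1 or name.endswith('.bias') or name in skip_list:
                no_decay.append(param)
            else:
                decay.append(param)
        return [
            {'params': no_decay, 'weight_decay': 0.},
            {'params': decay, 'weight_decay': weight_decay}]

Left unparenthesized, skip would now bind the method object rather than the set, and the name-in-skip_list test would never match, silently applying weight decay to pos_embed and cls_token.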