Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Fix pos_embed scaling for ViT and num_classes != 1000 for pretrained distilled deit and pit models. Fix #426 and fix #433
commit 7953e5d11a
parent a760a4c3f4
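What the fix enables, as a hedged usage sketch (not part of the diff; it assumes a timm build that includes this commit and uses model names touched below):

import timm

# Both 'head' and 'head_dist' pretrained weights are now discarded together when the
# requested class count differs from the pretrained 1000 classes, so this loads cleanly.
model = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=True, num_classes=10)
pit = timm.create_model('pit_s_distilled_224', pretrained=True, num_classes=10)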
timm/models/helpers.py
@@ -198,20 +198,24 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
                 _logger.warning(
                     f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
 
-    classifier_name = default_cfg.get('classifier', None)
+    classifiers = default_cfg.get('classifier', None)
     label_offset = default_cfg.get('label_offset', 0)
-    if classifier_name is not None:
+    if classifiers is not None:
+        if isinstance(classifiers, str):
+            classifiers = (classifiers,)
         if num_classes != default_cfg['num_classes']:
-            # completely discard fully connected if model num_classes doesn't match pretrained weights
-            del state_dict[classifier_name + '.weight']
-            del state_dict[classifier_name + '.bias']
+            for classifier_name in classifiers:
+                # completely discard fully connected if model num_classes doesn't match pretrained weights
+                del state_dict[classifier_name + '.weight']
+                del state_dict[classifier_name + '.bias']
             strict = False
         elif label_offset > 0:
-            # special case for pretrained weights with an extra background class in pretrained weights
-            classifier_weight = state_dict[classifier_name + '.weight']
-            state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
-            classifier_bias = state_dict[classifier_name + '.bias']
-            state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+            for classifier_name in classifiers:
+                # special case for pretrained weights with an extra background class in pretrained weights
+                classifier_weight = state_dict[classifier_name + '.weight']
+                state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+                classifier_bias = state_dict[classifier_name + '.bias']
+                state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
 
     model.load_state_dict(state_dict, strict=strict)
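A standalone sketch of the classifier handling above (toy dictionary, illustrative values only): default_cfg['classifier'] may now be a single name or a tuple of names, and every listed head is pruned when the requested num_classes differs from the pretrained weights.

# toy state_dict standing in for real pretrained tensors
state_dict = {'head.weight': 'W', 'head.bias': 'b', 'head_dist.weight': 'Wd', 'head_dist.bias': 'bd'}
classifiers = ('head', 'head_dist')   # may also arrive as the plain string 'head'
if isinstance(classifiers, str):
    classifiers = (classifiers,)
num_classes, pretrained_num_classes = 10, 1000
if num_classes != pretrained_num_classes:
    for classifier_name in classifiers:
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
    strict = False  # the checkpoint no longer covers the freshly initialized heads
print(state_dict)  # {}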
timm/models/pit.py
@@ -49,14 +49,17 @@ default_cfgs = {
     'pit_b_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
     'pit_ti_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
+        classifier=('head', 'head_dist')),
     'pit_xs_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
+        classifier=('head', 'head_dist')),
     'pit_s_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
+        classifier=('head', 'head_dist')),
     'pit_b_distilled_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
+        classifier=('head', 'head_dist')),
 }
 
timm/models/vision_transformer.py
@@ -123,14 +123,17 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
     'vit_deit_tiny_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_small_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
+        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
 }
 
@@ -302,6 +305,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 2 if distilled else 1
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
 
         if hybrid_backbone is not None:
@@ -313,9 +317,8 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
-        num_tokens = 2 if distilled else 1
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
@@ -382,10 +385,10 @@ class VisionTransformer(nn.Module):
         x = self.pos_drop(x + self.pos_embed)
         x = self.blocks(x)
         x = self.norm(x)
-        if self.dist_token is not None:
-            return x[:, 0], x[:, 1]
-        else:
+        if self.dist_token is None:
             return self.pre_logits(x[:, 0])
+        else:
+            return x[:, 0], x[:, 1]
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -401,15 +404,13 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def resize_pos_embed(posemb, posemb_new, token='class'):
+def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if token:
-        assert token in ('class', 'distill')
-        token_idx = 2 if token == 'distill' else 1
-        posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
+    if num_tokens:
+        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
         ntok_new -= 1
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
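Why num_tokens matters: for distilled models two prefix embeddings (class token plus distillation token) must be carried over unchanged, and only the patch-grid portion of pos_embed interpolated. A rough shape-only sketch of that bookkeeping, assuming a 224px to 384px resize with 16x16 patches (random values, not the library function verbatim):

import torch
import torch.nn.functional as F

num_tokens = 2                                           # class token + distillation token
posemb = torch.randn(1, num_tokens + 14 * 14, 768)       # pretrained 224px embedding
posemb_new = torch.zeros(1, num_tokens + 24 * 24, 768)   # target 384px embedding

posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
gs_old = int(posemb_grid.shape[0] ** 0.5)                    # 14
gs_new = int((posemb_new.shape[1] - num_tokens) ** 0.5)      # 24
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
resized = torch.cat([posemb_tok, posemb_grid], dim=1)
assert resized.shape == posemb_new.shape                     # torch.Size([1, 578, 768])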
@@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
         out_dict[k] = v
     return out_dict
 