Fix pos_embed scaling for ViT and num_classes != 1000 for pretrained distilled deit and pit models. Fix #426 and fix #433
parent
a760a4c3f4
commit
7953e5d11a
|
@ -198,20 +198,24 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
|
|||
_logger.warning(
|
||||
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
|
||||
|
||||
classifier_name = default_cfg.get('classifier', None)
|
||||
classifiers = default_cfg.get('classifier', None)
|
||||
label_offset = default_cfg.get('label_offset', 0)
|
||||
if classifier_name is not None:
|
||||
if classifiers is not None:
|
||||
if isinstance(classifiers, str):
|
||||
classifiers = (classifiers,)
|
||||
if num_classes != default_cfg['num_classes']:
|
||||
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
||||
del state_dict[classifier_name + '.weight']
|
||||
del state_dict[classifier_name + '.bias']
|
||||
for classifier_name in classifiers:
|
||||
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
||||
del state_dict[classifier_name + '.weight']
|
||||
del state_dict[classifier_name + '.bias']
|
||||
strict = False
|
||||
elif label_offset > 0:
|
||||
# special case for pretrained weights with an extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||
for classifier_name in classifiers:
|
||||
# special case for pretrained weights with an extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
|
|
@ -49,14 +49,17 @@ default_cfgs = {
|
|||
'pit_b_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
|
||||
'pit_ti_distilled_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'pit_xs_distilled_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'pit_s_distilled_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'pit_b_distilled_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'),
|
||||
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -123,14 +123,17 @@ default_cfgs = {
|
|||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_deit_tiny_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'vit_deit_small_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'vit_deit_base_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'vit_deit_base_distilled_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
|
||||
}
|
||||
|
||||
|
||||
|
@ -302,6 +305,7 @@ class VisionTransformer(nn.Module):
|
|||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_tokens = 2 if distilled else 1
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
if hybrid_backbone is not None:
|
||||
|
@ -313,9 +317,8 @@ class VisionTransformer(nn.Module):
|
|||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
|
||||
num_tokens = 2 if distilled else 1
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
@ -382,10 +385,10 @@ class VisionTransformer(nn.Module):
|
|||
x = self.pos_drop(x + self.pos_embed)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
if self.dist_token is not None:
|
||||
return x[:, 0], x[:, 1]
|
||||
else:
|
||||
if self.dist_token is None:
|
||||
return self.pre_logits(x[:, 0])
|
||||
else:
|
||||
return x[:, 0], x[:, 1]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -401,15 +404,13 @@ class VisionTransformer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def resize_pos_embed(posemb, posemb_new, token='class'):
|
||||
def resize_pos_embed(posemb, posemb_new, num_tokens=1):
|
||||
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
||||
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
||||
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if token:
|
||||
assert token in ('class', 'distill')
|
||||
token_idx = 2 if token == 'distill' else 1
|
||||
posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
|
||||
if num_tokens:
|
||||
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
||||
ntok_new -= 1
|
||||
else:
|
||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
||||
|
@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|||
v = v.reshape(O, -1, H, W)
|
||||
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
||||
# To resize pos embedding when using model at different size from pretrained weights
|
||||
v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
|
||||
v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
|
Loading…
Reference in New Issue