Fix a few issues loading pretrained vit/bit npz weights w/ num_classes=0 __init__ arg. Missed a few other small classifier handling detail on Mlp, GhostNet, Levit. Should fix #713
parent
dc422820ec
commit
b41cffaa93
|
@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
|
|||
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
|
||||
if 'pruned' not in model_name: # FIXME better pruned model handling
|
||||
# test classifier + global pool deletion via __init__
|
||||
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 4
|
||||
if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
|
||||
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
|
||||
# check classifier name matches default_cfg
|
||||
classifier = cfg['classifier']
|
||||
if not isinstance(classifier, (tuple, list)):
|
||||
|
@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
|
|||
assert len(outputs.shape) == 2
|
||||
assert outputs.shape[1] == model.num_features
|
||||
|
||||
model = create_model(model_name, pretrained=False, num_classes=0).eval()
|
||||
outputs = model.forward(input_tensor)
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = outputs[0]
|
||||
assert len(outputs.shape) == 2
|
||||
assert outputs.shape[1] == model.num_features
|
||||
|
||||
# check classifier name matches default_cfg
|
||||
classifier = cfg['classifier']
|
||||
if not isinstance(classifier, (tuple, list)):
|
||||
|
@ -217,6 +233,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
|
|||
"""Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
|
||||
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
|
||||
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
|
||||
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
|
||||
|
|
|
@ -182,7 +182,7 @@ class GhostNet(nn.Module):
|
|||
self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
|
||||
self.act2 = nn.ReLU(inplace=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(out_chs, num_classes)
|
||||
self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def get_classifier(self):
|
||||
return self.classifier
|
||||
|
|
|
@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|||
state_dict = state_dict['model']
|
||||
D = model.state_dict()
|
||||
for k in state_dict.keys():
|
||||
if D[k].ndim == 4 and state_dict[k].ndim == 2:
|
||||
if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
|
||||
state_dict[k] = state_dict[k][:, :, None, None]
|
||||
return state_dict
|
||||
|
||||
|
|
|
@ -266,7 +266,7 @@ class MlpMixer(nn.Module):
|
|||
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
|
||||
for _ in range(num_blocks)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Linear(embed_dim, self.num_classes) # zero init
|
||||
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.init_weights(nlhb=nlhb)
|
||||
|
||||
|
|
|
@ -424,7 +424,8 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/
|
|||
model.stem.conv.weight.copy_(stem_conv_w)
|
||||
model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
|
||||
model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
|
||||
if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
|
||||
if isinstance(model.head.fc, nn.Conv2d) and \
|
||||
model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
|
||||
model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
|
||||
model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
|
||||
for i, (sname, stage) in enumerate(model.stages.named_children()):
|
||||
|
|
|
@ -237,7 +237,6 @@ class Visformer(nn.Module):
|
|||
self.num_features = embed_dim if self.vit_stem else embed_dim * 2
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
self.head = nn.Linear(self.num_features, num_classes)
|
||||
|
||||
# weights init
|
||||
if self.pos_embed:
|
||||
|
|
|
@ -448,7 +448,7 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
|
|||
model.pos_embed.copy_(pos_embed_w)
|
||||
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||
if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
||||
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
||||
for i, block in enumerate(model.blocks.children()):
|
||||
|
|
Loading…
Reference in New Issue