mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix ResNetV2 pretrained classifier issue. Fixes #540
This commit is contained in:
parent
de9dff933a
commit
2b49ab7a36
@ -132,7 +132,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
|
||||
def test_model_load_pretrained(model_name, batch_size):
|
||||
"""Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
|
||||
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
|
||||
create_model(model_name, pretrained=True, in_chans=in_chans)
|
||||
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
|
||||
|
@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.num_classes = num_classes
|
||||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
@ -393,8 +394,9 @@ class ResNetV2(nn.Module):
|
||||
self.stem.conv.weight.copy_(stem_conv_w)
|
||||
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
|
||||
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
|
||||
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
||||
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
|
||||
if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
|
||||
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
||||
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
|
||||
for i, (sname, stage) in enumerate(self.stages.named_children()):
|
||||
for j, (bname, block) in enumerate(stage.blocks.named_children()):
|
||||
convname = 'standardized_conv2d'
|
||||
|
Loading…
x
Reference in New Issue
Block a user