mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix a few more issues related to #216 w/ TResNet (space2depth) and FP16 weights in wide resnets. Also don't completely dump pretrained weights in in_chans != 1 or 3 cases.
This commit is contained in:
parent
512b2dd645
commit
b1b6e7c361
@ -115,8 +115,9 @@ if 'GITHUB_ACTIONS' not in os.environ:
|
|||||||
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
|
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
|
||||||
@pytest.mark.parametrize('batch_size', [1])
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
def test_model_load_pretrained(model_name, batch_size):
|
def test_model_load_pretrained(model_name, batch_size):
|
||||||
"""Run a single forward pass with each model"""
|
"""Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
|
||||||
create_model(model_name, pretrained=True)
|
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
|
||||||
|
create_model(model_name, pretrained=True, in_chans=in_chans)
|
||||||
|
|
||||||
|
|
||||||
EXCLUDE_JIT_FILTERS = [
|
EXCLUDE_JIT_FILTERS = [
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
|
""" Model creation / weight loading / state_dict helpers
|
||||||
|
|
||||||
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
|||||||
|
|
||||||
if in_chans == 1:
|
if in_chans == 1:
|
||||||
conv1_name = cfg['first_conv']
|
conv1_name = cfg['first_conv']
|
||||||
_logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
|
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
|
||||||
conv1_weight = state_dict[conv1_name + '.weight']
|
conv1_weight = state_dict[conv1_name + '.weight']
|
||||||
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
|
# Some weights are in torch.half, ensure it's float for sum on CPU
|
||||||
|
conv1_type = conv1_weight.dtype
|
||||||
|
conv1_weight = conv1_weight.float()
|
||||||
|
O, I, J, K = conv1_weight.shape
|
||||||
|
if I > 3:
|
||||||
|
assert conv1_weight.shape[1] % 3 == 0
|
||||||
|
# For models with space2depth stems
|
||||||
|
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
|
||||||
|
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
|
||||||
|
else:
|
||||||
|
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
|
||||||
|
conv1_weight = conv1_weight.to(conv1_type)
|
||||||
|
state_dict[conv1_name + '.weight'] = conv1_weight
|
||||||
elif in_chans != 3:
|
elif in_chans != 3:
|
||||||
assert False, "Invalid in_chans for pretrained weights"
|
conv1_name = cfg['first_conv']
|
||||||
|
conv1_weight = state_dict[conv1_name + '.weight']
|
||||||
|
conv1_type = conv1_weight.dtype
|
||||||
|
conv1_weight = conv1_weight.float()
|
||||||
|
O, I, J, K = conv1_weight.shape
|
||||||
|
if I != 3:
|
||||||
|
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
|
||||||
|
del state_dict[conv1_name + '.weight']
|
||||||
|
strict = False
|
||||||
|
else:
|
||||||
|
# NOTE this strategy should be better than random init, but there could be other combinations of
|
||||||
|
# the original RGB input layer weights that'd work better for specific cases.
|
||||||
|
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
|
||||||
|
repeat = int(math.ceil(in_chans / 3))
|
||||||
|
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
||||||
|
conv1_weight *= (3 / float(in_chans))
|
||||||
|
conv1_weight = conv1_weight.to(conv1_type)
|
||||||
|
state_dict[conv1_name + '.weight'] = conv1_weight
|
||||||
|
|
||||||
classifier_name = cfg['classifier']
|
classifier_name = cfg['classifier']
|
||||||
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user