Fix a few more issues related to #216 w/ TResNet (space2depth) and FP16 weights in wide resnets. Also don't completely dump pretrained weights when in_chans is neither 1 nor 3.
parent 512b2dd645
commit b1b6e7c361
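In essence, the reworked in_chans == 1 path in load_pretrained() (last hunk below) folds the pretrained RGB first-conv weights down to a single channel, covering two cases the old one-line sum missed: FP16 checkpoints (torch.half tensors couldn't be summed on CPU, so the weight is round-tripped through float32) and space2depth stems whose first conv sees N x 3 input channels. A minimal standalone sketch with illustrative shapes and a hypothetical function name, not the timm helper itself:

import torch

def first_conv_to_1ch(conv1_weight: torch.Tensor) -> torch.Tensor:
    # FP16 checkpoints (e.g. some wide resnets) can't be summed on CPU,
    # so compute in float32 and restore the original dtype afterwards
    conv1_type = conv1_weight.dtype
    conv1_weight = conv1_weight.float()
    O, I, J, K = conv1_weight.shape
    if I > 3:
        # space2depth stem: input channels are groups of RGB triplets, one per
        # spatial position, so collapse each triplet instead of summing all to 1
        assert I % 3 == 0
        conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K).sum(dim=2)
    else:
        conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
    return conv1_weight.to(conv1_type)

# ordinary RGB stem in half precision: (64, 3, 7, 7) -> (64, 1, 7, 7)
print(first_conv_to_1ch(torch.randn(64, 3, 7, 7).half()).shape)
# TResNet-style space2depth stem: (64, 48, 3, 3) -> (64, 16, 3, 3)
print(first_conv_to_1ch(torch.randn(64, 48, 3, 3)).shape)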
tests/test_models.py

@@ -115,8 +115,9 @@ if 'GITHUB_ACTIONS' not in os.environ:
 @pytest.mark.parametrize('model_name', list_models(pretrained=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_load_pretrained(model_name, batch_size):
-    """Run a single forward pass with each model"""
-    create_model(model_name, pretrained=True)
+    """Check that pretrained weights load, verify support for in_chans != 3 while doing so."""
+    in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
+    create_model(model_name, pretrained=True, in_chans=in_chans)
 
 
 EXCLUDE_JIT_FILTERS = [
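From user code, the path this test now exercises is simply the in_chans argument at model creation; the model name below is illustrative, and note the test keeps pruned variants at 3 channels since pruning doesn't yet support the in_chans change:

import timm

# loads the 3-channel pretrained checkpoint, then adapts the first conv to 1 channel
model = timm.create_model('resnet18', pretrained=True, in_chans=1)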
timm/models/helpers.py

@@ -1,5 +1,10 @@
 """ Model creation / weight loading / state_dict helpers
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import os
+import math
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Callable
@@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None
     if in_chans == 1:
         conv1_name = cfg['first_conv']
-        _logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
+        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
         conv1_weight = state_dict[conv1_name + '.weight']
-        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
+        # Some weights are in torch.half, ensure it's float for sum on CPU
+        conv1_type = conv1_weight.dtype
+        conv1_weight = conv1_weight.float()
+        O, I, J, K = conv1_weight.shape
+        if I > 3:
+            assert conv1_weight.shape[1] % 3 == 0
+            # For models with space2depth stems
+            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
+            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
+        else:
+            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
+        conv1_weight = conv1_weight.to(conv1_type)
+        state_dict[conv1_name + '.weight'] = conv1_weight
     elif in_chans != 3:
-        assert False, "Invalid in_chans for pretrained weights"
+        conv1_name = cfg['first_conv']
+        conv1_weight = state_dict[conv1_name + '.weight']
+        conv1_type = conv1_weight.dtype
+        conv1_weight = conv1_weight.float()
+        O, I, J, K = conv1_weight.shape
+        if I != 3:
+            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
+            del state_dict[conv1_name + '.weight']
+            strict = False
+        else:
+            # NOTE this strategy should be better than random init, but there could be other combinations of
+            # the original RGB input layer weights that'd work better for specific cases.
+            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
+            repeat = int(math.ceil(in_chans / 3))
+            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+            conv1_weight *= (3 / float(in_chans))
+            conv1_weight = conv1_weight.to(conv1_type)
+            state_dict[conv1_name + '.weight'] = conv1_weight
 
     classifier_name = cfg['classifier']
     if num_classes == 1000 and cfg['num_classes'] == 1001:
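The replacement for the old assert (the in_chans != 3 branch above) no longer throws usable weights away: if the pretrained first conv is itself non-RGB it is deleted and loading falls back to non-strict, otherwise the RGB kernel is tiled across the channel dim and rescaled. A standalone sketch of that tile-and-rescale step, again with an illustrative name and shapes:

import math
import torch

def first_conv_to_n_ch(conv1_weight: torch.Tensor, in_chans: int) -> torch.Tensor:
    conv1_type = conv1_weight.dtype
    conv1_weight = conv1_weight.float()
    # tile the 3-channel kernel enough times to cover in_chans, then truncate
    repeat = int(math.ceil(in_chans / 3))
    conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
    # summing over more (fewer) input channels than 3 would scale activations
    # up (down) by roughly in_chans / 3, so compensate to keep magnitudes stable
    conv1_weight *= 3 / float(in_chans)
    return conv1_weight.to(conv1_type)

# RGB kernel adapted for 5-channel input: (64, 3, 7, 7) -> (64, 5, 7, 7)
print(first_conv_to_n_ch(torch.randn(64, 3, 7, 7), in_chans=5).shape)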