Fix #1912 CoaT model not loading w/ return_interm_layers

This commit is contained in:
Ross Wightman 2023-08-10 11:15:58 -07:00
parent c692715388
commit 3a44e6c602

View File

@ -690,8 +690,11 @@ def checkpoint_filter_fn(state_dict, model):
for k, v in state_dict.items():
# original model had unused norm layers, removing them requires filtering pretrained checkpoints
if k.startswith('norm1') or \
(model.norm2 is None and k.startswith('norm2')) or \
(model.norm3 is None and k.startswith('norm3')):
(k.startswith('norm2') and getattr(model, 'norm2', None) is None) or \
(k.startswith('norm3') and getattr(model, 'norm3', None) is None) or \
(k.startswith('norm4') and getattr(model, 'norm4', None) is None) or \
(k.startswith('aggregate') and getattr(model, 'aggregate', None) is None) or \
(k.startswith('head') and getattr(model, 'head', None) is None):
continue
out_dict[k] = v
return out_dict