mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
With this update, one can tune the kind of logs generated by timm, while training and inference traces remain unchanged.
200 lines
7.7 KiB
Python
200 lines
7.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from copy import deepcopy
|
|
import torch.utils.model_zoo as model_zoo
|
|
import os
|
|
import logging
|
|
from collections import OrderedDict
|
|
from timm.models.layers.conv2d_same import Conv2dSame
|
|
|
|
|
|
# Module-level logger used by the checkpoint/pretrained helpers in this file; its level and
# handlers are controlled through the standard ``logging`` configuration of the application.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def load_state_dict(checkpoint_path, use_ema=False):
    """Load a state dict from a checkpoint file on disk.

    If the loaded checkpoint is a dict containing ``'state_dict'`` (or ``'state_dict_ema'``
    when ``use_ema`` is set and that key is present), that entry is returned as an
    ``OrderedDict`` with any leading ``'module.'`` prefix (as added by DataParallel-style
    wrappers) removed from its keys. Any other loaded object is returned unchanged.

    Args:
        checkpoint_path: path to a file readable by ``torch.load``.
        use_ema: prefer the ``'state_dict_ema'`` entry when the checkpoint has one.

    Returns:
        The extracted state dict, or the raw loaded checkpoint.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty/None or not an existing file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    is_dict = isinstance(checkpoint, dict)
    if is_dict and use_ema and 'state_dict_ema' in checkpoint:
        state_dict_key = 'state_dict_ema'
    else:
        state_dict_key = 'state_dict'

    # The isinstance guard avoids a TypeError from `in` on non-container checkpoints.
    if is_dict and state_dict_key in checkpoint:
        # Strip only the exact 'module.' prefix. Checking startswith('module') and slicing
        # 7 chars corrupted keys such as 'modulex.bias' or 'modules.0.weight'.
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint[state_dict_key].items())
    else:
        state_dict = checkpoint
    logger.info("Loaded %s from checkpoint '%s'", state_dict_key, checkpoint_path)
    return state_dict
|
|
|
|
|
|
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    """Load the weights stored at ``checkpoint_path`` into ``model`` in place.

    The file is read with ``load_state_dict`` (which picks the EMA weights when ``use_ema``
    is set and the checkpoint has them). The result is applied via ``model.load_state_dict``
    with ``strict`` forwarded unchanged. Nothing is returned.
    """
    model.load_state_dict(load_state_dict(checkpoint_path, use_ema), strict=strict)
|
|
|
|
|
|
def resume_checkpoint(model, checkpoint_path):
    """Restore ``model`` weights and auxiliary training state from a checkpoint file.

    If the checkpoint is a dict with a ``'state_dict'`` entry, that entry is loaded strictly
    into ``model`` after stripping a leading ``'module.'`` prefix from its keys. The optional
    ``'optimizer'`` / ``'amp'`` entries and the epoch are collected. Any other checkpoint is
    passed to ``model.load_state_dict`` as-is.

    Args:
        model: module whose weights are restored in place.
        checkpoint_path: path to a file readable by ``torch.load``.

    Returns:
        ``(other_state, resume_epoch)``, where:
        - ``other_state`` holds whichever of ``'optimizer'`` and ``'amp'`` the checkpoint contains.
        - ``resume_epoch`` is the stored epoch, plus one for checkpoints with ``version > 1``,
          or ``None`` when no epoch was stored.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty/None or not an existing file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    other_state = {}
    resume_epoch = None
    # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Strip only the exact 'module.' prefix (not any key merely starting with 'module').
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint['state_dict'].items())
        model.load_state_dict(state_dict)
        for key in ('optimizer', 'amp'):
            if key in checkpoint:
                other_state[key] = checkpoint[key]
        if 'epoch' in checkpoint:
            resume_epoch = checkpoint['epoch']
            if 'version' in checkpoint and checkpoint['version'] > 1:
                resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
        # .get(): checkpoints without an 'epoch' entry previously crashed here with KeyError.
        logger.info("Loaded checkpoint '%s' (epoch %s)", checkpoint_path, checkpoint.get('epoch'))
    else:
        model.load_state_dict(checkpoint)
        logger.info("Loaded checkpoint '%s'", checkpoint_path)
    return other_state, resume_epoch
|
|
|
|
|
|
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
    """Download pretrained weights described by ``cfg`` and load them into ``model``.

    Args:
        model: module to load the weights into.
        cfg: pretrained-config dict. Defaults to ``model.default_cfg`` when that attribute
            exists. Keys read: ``'url'``, ``'first_conv'`` (only when ``in_chans == 1``),
            ``'classifier'`` and ``'num_classes'``.
        num_classes: classifier size of ``model``.
            - If it is 1000 while the weights have 1001 outputs, the first classifier row is
              dropped.
            - Any other mismatch discards the pretrained classifier and loads non-strictly.
        in_chans: 1 or 3. For 1, the first conv's weights are summed over dim 1 to produce a
            single-channel filter.
        filter_fn: optional callable applied to the downloaded state dict before loading.
        strict: forwarded to ``model.load_state_dict``; forced to False when the classifier
            weights are discarded.

    Returns:
        None. If no usable URL is available, a warning is logged and nothing is loaded.

    Raises:
        AssertionError: if ``in_chans`` is not 1 or 3 (checked before any download).
    """
    if cfg is None:
        # Default of None so models without a default_cfg reach the warning below
        # instead of raising AttributeError.
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    # Validate before downloading so bad arguments fail fast. The explicit raise stays active
    # under `python -O` (unlike `assert`); AssertionError is kept for backward compatibility.
    if in_chans not in (1, 3):
        raise AssertionError("Invalid in_chans for pretrained weights")

    state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        logger.info('Converting first conv (%s) from 3 to 1 channel', conv1_name)
        conv1_weight = state_dict[conv1_name + '.weight']
        # Sum the 3 input-channel filters into one (dim 1 is in_channels for torch conv weights).
        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)

    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)
|
|
|
|
|
|
def extract_layer(model, layer):
    """Return the submodule of ``model`` addressed by the dotted path ``layer``.

    - A model exposing a ``module`` attribute is treated as a wrapper: paths that do not
      start with ``'module'`` are resolved against ``model.module``.
    - For a model without that attribute, a leading ``'module'`` component is ignored.
    - Numeric components are used as indices; the others as attribute names.
    - Resolution stops at the first component that does not exist, and the deepest module
      reached so far is returned.
    """
    parts = layer.split('.')
    wrapped = hasattr(model, 'module')
    current = model.module if (wrapped and parts[0] != 'module') else model
    if not wrapped and parts[0] == 'module':
        parts = parts[1:]
    for name in parts:
        if not hasattr(current, name):
            break
        current = current[int(name)] if name.isdigit() else getattr(current, name)
    return current
|
|
|
|
|
|
def set_layer(model, layer, val):
    """Replace the submodule of ``model`` at dotted path ``layer`` with ``val`` (in place).

    If ``model`` has a ``module`` attribute and ``layer`` does not start with ``'module'``,
    the path is resolved against ``model.module``. Numeric components are used as indices,
    the others as attribute names. When every component of the path exists (e.g. names
    produced by ``named_modules()``), this walks ``layer[:-1]`` and does
    ``setattr(parent, layer[-1], val)``.

    NOTE(review): edge cases to be aware of:
    - ``lst_index`` counts every component found while walking, not just a found prefix.
      If some component is missing, the computed parent/attribute shifts and a different
      attribute than intended may be replaced.
    - Unlike ``extract_layer``, a leading ``'module'`` is NOT stripped for unwrapped models.
    """
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    # First pass: count how many path components exist (walking a separate cursor).
    lst_index = 0
    module2 = module
    for l in layer:
        if hasattr(module2, l):
            if not l.isdigit():
                module2 = getattr(module2, l)
            else:
                module2 = module2[int(l)]
            lst_index += 1
    # Index of the last found component; everything before it is the parent path.
    lst_index -= 1
    # Second pass: walk to the parent of the target without existence checks.
    for l in layer[:lst_index]:
        if not l.isdigit():
            module = getattr(module, l)
        else:
            module = module[int(l)]
    l = layer[lst_index]
    # setattr is used even for numeric names; presumably this relies on nn.Module.__setattr__
    # registering the child under that name (how Sequential entries are keyed) - TODO confirm.
    setattr(module, l, val)
|
|
|
|
|
|
def _parse_model_string(model_string, separator='***'):
    """Parse ``'<param>:[d0, d1, ...]'`` entries joined by ``separator`` into ``{param: [d0, ...]}``.

    Blank segments (e.g. from a trailing separator) are skipped, whitespace around keys and
    shapes is ignored, and entries with an empty shape (``'[]'``) are omitted.
    """
    shapes = {}
    for entry in model_string.split(separator):
        entry = entry.strip()
        if not entry:
            continue
        key, shape_str = entry.split(':', 1)
        dims = shape_str.strip()[1:-1].split(',')
        if dims[0].strip():
            shapes[key.strip()] = [int(d) for d in dims]
    return shapes


def adapt_model_from_string(parent_module, model_string):
    """Build a copy of ``parent_module`` whose layers are resized to the shapes in ``model_string``.

    ``model_string`` is a ``'***'``-separated list of ``'<param name>:[dims]'`` entries.
    For every ``nn.Conv2d`` / ``Conv2dSame``, ``nn.BatchNorm2d`` and ``nn.Linear`` in
    ``parent_module``:
    - a new layer is constructed using the ``'<name>.weight'`` shape from the string;
    - it is swapped into a deep copy of the model. Weights are NOT copied; the new layers
      keep their constructor initialisation.

    Both the returned copy and ``parent_module`` are switched to eval mode.

    Raises:
        KeyError: if a conv/bn/linear layer has no ``'<name>.weight'`` entry in the string.
    """
    shapes = _parse_model_string(model_string)

    new_module = deepcopy(parent_module)
    for name, _ in parent_module.named_modules():
        # Resolved the same way set_layer resolves it, so lookups and replacements agree.
        old_module = extract_layer(parent_module, name)
        if isinstance(old_module, (nn.Conv2d, Conv2dSame)):
            conv = Conv2dSame if isinstance(old_module, Conv2dSame) else nn.Conv2d
            weight_shape = shapes[name + '.weight']
            out_channels = weight_shape[0]
            in_channels = weight_shape[1]
            groups = 1
            if old_module.groups > 1:
                # Grouped convs are rebuilt as depthwise (groups == in == out channels);
                # assumes every grouped conv in these models is depthwise - TODO confirm.
                in_channels = out_channels
                groups = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=groups, stride=old_module.stride)
            set_layer(new_module, name, new_conv)
        elif isinstance(old_module, nn.BatchNorm2d):
            # track_running_stats is always enabled, regardless of the original layer.
            new_bn = nn.BatchNorm2d(
                num_features=shapes[name + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, name, new_bn)
        elif isinstance(old_module, nn.Linear):
            # Only the input width comes from the string; out_features is kept from the original.
            new_fc = nn.Linear(
                in_features=shapes[name + '.weight'][1], out_features=old_module.out_features,
                bias=old_module.bias is not None)
            set_layer(new_module, name, new_fc)
    new_module.eval()
    parent_module.eval()

    return new_module
|
|
|
|
|
|
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using the shape spec stored in ``pruned/<model_variant>.txt``.

    The spec file is looked up in a ``pruned`` directory next to this module. Its contents,
    stripped of surrounding whitespace, are handed to ``adapt_model_from_string``, and the
    adapted copy is returned.
    """
    here = os.path.dirname(__file__)
    spec_path = os.path.join(here, 'pruned', model_variant + '.txt')
    with open(spec_path, 'r') as spec_file:
        spec = spec_file.read().strip()
    return adapt_model_from_string(parent_module, spec)
|