import logging
|
|
import os
|
|
import time
|
|
import pickle
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from utils.distributed import is_main_process
|
|
|
|
# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)
|
# Known normalization layer types. Extend this registry with the
# `register_norm_module` decorator defined below.
NORM_MODULES = [
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    nn.GroupNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
    nn.LayerNorm,
    nn.LocalResponseNorm,
]
|
|
|
|
def register_norm_module(cls):
    """Class decorator: record *cls* in ``NORM_MODULES`` and return it unchanged.

    Lets external code mark custom layers as normalization modules so they
    receive the same treatment as the built-in torch norm layers.
    """
    NORM_MODULES.append(cls)
    return cls
|
|
|
|
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
    """Select the checkpoint weights that can be loaded into the model.

    A checkpoint entry is kept only when its key exists in the model and the
    tensor shapes agree. On the main process, every key's fate is logged:
    loaded, shape-mismatched (``*UNMATCHED*``), absent from the checkpoint
    (``*UNLOADED*``), or present in the checkpoint but unused by the model
    (``$UNUSED$``).

    Args:
        model_state_dict: ``state_dict`` of the model being initialized.
        ckpt_state_dict: ``state_dict`` loaded from a checkpoint.

    Returns:
        dict: subset of ``ckpt_state_dict`` safe to pass to ``load_state_dict``.
    """
    model_keys = sorted(model_state_dict.keys())
    # Track checkpoint keys not yet consumed. A set gives O(1) membership and
    # removal; the previous list-based `index`/`pop` made the loop quadratic
    # in the number of parameters.
    unused_ckpt_keys = set(ckpt_state_dict.keys())
    result_dicts = {}
    matched_log = []
    unmatched_log = []
    unloaded_log = []
    for model_key in model_keys:
        model_weight = model_state_dict[model_key]
        if model_key not in unused_ckpt_keys:
            unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
            continue
        ckpt_weight = ckpt_state_dict[model_key]
        if model_weight.shape == ckpt_weight.shape:
            result_dicts[model_key] = ckpt_weight
            # Only matched keys are consumed: shape-mismatched keys remain in
            # the unused set so they also show up in the $UNUSED$ report,
            # matching the original behavior.
            unused_ckpt_keys.discard(model_key)
            matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
        else:
            unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))

    # Only the main process logs, so multi-GPU runs don't duplicate output.
    if is_main_process():
        for info in matched_log:
            logger.info(info)
        for info in unloaded_log:
            logger.warning(info)
        # sorted() reproduces the original ordering, which iterated a sorted
        # key list with matched entries removed.
        for key in sorted(unused_ckpt_keys):
            logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
        for info in unmatched_log:
            logger.warning(info)
    return result_dicts