mirror of https://github.com/PyRetri/PyRetri.git
51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
|
|
import torch.nn as nn
|
|
from torch.nn import Parameter
|
|
|
|
from torchvision.models.utils import load_state_dict_from_url
|
|
|
|
from typing import Dict
|
|
|
|
def ensure_dir(path: str) -> None:
    """
    Make sure a directory exists, creating it and any missing parents if needed.

    Calling this on a directory that already exists is a no-op.

    Args:
        path (str): the path of the directory.

    Raises:
        OSError: if the directory cannot be created, e.g. ``path`` is empty or
            already exists as a non-directory file.
    """
    # exist_ok=True replaces the former "os.path.exists() then os.makedirs()"
    # sequence, which raced: if another process created the directory between
    # the check and the call, makedirs raised FileExistsError.
    os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
def load_state_dict(model: nn.Module, state_dict: Dict) -> None:
    """
    Leniently load parameters into a model, skipping entries that do not fit.

    This is a relaxed variant of PyTorch's ``nn.Module.load_state_dict``.
    Entries are handled as follows, each reported with a printed message
    rather than an exception:

    - Copied: an entry whose name and shape match the model.
    - Skipped: unexpected keys, entries whose shape differs from the model's,
      and entries whose copy fails.
    - Reported: model keys that were not successfully loaded, listed at the end.

    Args:
        model (nn.Module): the model for extracting features.
        state_dict (Dict): a dict mapping parameter/buffer names to tensors.
    """
    own_state = model.state_dict()
    success_keys = set()
    for name, param in state_dict.items():
        if name not in own_state:
            print("[LoadStateDict]: " + 'unexpected key "{}" in state_dict'.format(name))
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        target = own_state[name]
        # Tensor.copy_ broadcasts its source, so without this explicit check a
        # mismatched but broadcastable tensor (e.g. shape (1,) into (N,)) would
        # be silently "loaded" and counted as a success.
        if target.size() != param.size():
            print("[LoadStateDict]: shape mismatch in parameter {}, {} vs {}".format(
                name, target.size(), param.size()
            ))
            continue
        try:
            # target shares storage with the model, so this writes into the model
            target.copy_(param)
        except RuntimeError as err:
            print("[LoadStateDict]: failed to copy parameter {}: {}".format(name, err))
            continue
        success_keys.add(name)
    # sorted for deterministic, readable output
    missing = sorted(set(own_state.keys()) - success_keys)
    if len(missing) > 0:
        print("[LoadStateDict]: " + "missing keys or mismatch param in state_dict: {}".format(missing))
|