14 lines
300 B
Python
Raw Normal View History

2020-09-07 13:56:00 -07:00
""" Model / state_dict utils
Hacked together by / Copyright 2020 Ross Wightman
"""
from .model_ema import ModelEma
def unwrap_model(model):
    """Strip a single layer of wrapping from *model*.

    If *model* has a ``module`` attribute, that attribute is returned.
    Otherwise *model* itself is returned unchanged. Only one level is
    removed: a wrapper around a wrapper yields the inner wrapper.

    Presumably this targets wrappers such as torch ``DataParallel`` /
    ``DistributedDataParallel`` that store the wrapped model as ``.module``.
    Note that any object exposing ``module`` is unwrapped, not only those.
    """
    if not hasattr(model, 'module'):
        return model
    return model.module
2020-09-07 13:56:00 -07:00
def get_state_dict(model, unwrap_fn=unwrap_model):
    """Return the ``state_dict()`` of *model* after unwrapping it.

    Args:
        model: The model, possibly wrapped (e.g. in an object exposing the
            real model as ``.module``).
        unwrap_fn: Callable applied to *model* before ``state_dict()`` is
            called. Defaults to ``unwrap_model``.

    Returns:
        The result of calling ``state_dict()`` on ``unwrap_fn(model)``.
    """
    bare_model = unwrap_fn(model)
    return bare_model.state_dict()