mirror of https://github.com/YifanXu74/MQ-Det.git
47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
from copy import deepcopy
|
|
from collections import OrderedDict
|
|
import torch
|
|
|
|
|
|
class ModelEma:
    """Maintain an exponential moving average (EMA) copy of a model.

    The EMA copy is a ``deepcopy`` of the model given at construction. It is
    put in eval mode, and ``requires_grad`` is disabled on all its parameters.
    Call :meth:`update` (typically after each optimizer step) to blend the
    live model's current state into the copy.

    Args:
        model: The ``torch.nn.Module`` to track. It may be a wrapper that
            exposes the real network as ``.module``.
        decay: EMA decay factor. Each update computes
            ``ema = decay * ema + (1 - decay) * model`` for floating-point
            entries.
        device: Optional device for the EMA copy. When falsy (the default
            ``''``), the copy stays wherever ``deepcopy`` placed it.
    """

    # Prefix that wrapper modules (objects exposing `.module`) add to
    # state-dict keys.
    _MODULE_PREFIX = 'module.'

    def __init__(self, model, decay=0.9999, device=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        # True when the EMA copy is itself a wrapper exposing `.module`; its
        # state-dict keys then carry the 'module.' prefix.
        self.ema_is_dp = hasattr(self.ema, 'module')
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def load_checkpoint(self, checkpoint):
        """Load EMA weights from a checkpoint dict or a checkpoint file path.

        Only the ``'model_ema'`` entry is used. If it is absent, the EMA copy
        is left unchanged. Keys are adapted to how the EMA copy is wrapped:
        a leading ``'module.'`` is added when the copy has ``.module``, and
        stripped (leading occurrence only) when it does not.

        Args:
            checkpoint: A dict, or a ``str`` path passed to ``torch.load``.

        Raises:
            AssertionError: If ``checkpoint`` is not a dict after loading
                (only when assertions are enabled).
            RuntimeError: Propagated from ``load_state_dict`` if the adapted
                keys do not match the EMA copy.
        """
        if isinstance(checkpoint, str):
            # Load onto CPU so GPU-saved checkpoints can be read anywhere;
            # load_state_dict then copies values onto the EMA copy's device.
            checkpoint = torch.load(checkpoint, map_location='cpu')

        assert isinstance(checkpoint, dict)
        if 'model_ema' not in checkpoint:
            return

        prefix = self._MODULE_PREFIX
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_ema'].items():
            has_prefix = k.startswith(prefix)
            if self.ema_is_dp:
                name = k if has_prefix else prefix + k
            else:
                # Strip only the leading prefix: str.replace would also
                # mangle nested names such as 'module.submodule.weight'.
                name = k[len(prefix):] if has_prefix else k
            new_state_dict[name] = v
        self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        """Return the state dict of the EMA copy."""
        return self.ema.state_dict()

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy in place.

        Floating-point entries become ``decay * ema + (1 - decay) * model``.
        Non-floating-point entries are copied directly. Blending them would
        produce a float that is cast back to the integer dtype, dropping the
        fraction, so integer buffers such as BatchNorm's
        ``num_batches_tracked`` would lag or stall.

        Args:
            model: The live model. If it is wrapped (has ``.module``) while
                the EMA copy is not, its keys are looked up with the
                ``'module.'`` prefix.
        """
        needs_prefix = hasattr(model, 'module') and not self.ema_is_dp
        with torch.no_grad():
            model_state = model.state_dict()
            # state_dict() tensors share storage with the EMA module's
            # parameters/buffers, so copy_ updates the EMA copy in place.
            for key, ema_v in self.ema.state_dict().items():
                if needs_prefix:
                    key = self._MODULE_PREFIX + key
                model_v = model_state[key].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
                else:
                    ema_v.copy_(model_v)
|
|
|