Fix #2127 move to ema device
parent
e25bbfceec
commit
24f6d4f7f8
|
@ -230,9 +230,9 @@ class ModelEmaV3(nn.Module):
|
|||
else:
|
||||
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
||||
if ema_v.is_floating_point():
|
||||
ema_v.lerp_(model_v, weight=1. - decay)
|
||||
ema_v.lerp_(model_v.to(device=self.device), weight=1. - decay)
|
||||
else:
|
||||
ema_v.copy_(model_v)
|
||||
ema_v.copy_(model_v.to(device=self.device))
|
||||
|
||||
def apply_update_no_buffers_(self, model, decay: float):
|
||||
# interpolate parameters, copy buffers
|
||||
|
@ -246,7 +246,7 @@ class ModelEmaV3(nn.Module):
|
|||
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
|
||||
else:
|
||||
for ema_p, model_p in zip(ema_params, model_params):
|
||||
ema_p.lerp_(model_p, weight=1. - decay)
|
||||
ema_p.lerp_(model_p.to(device=self.device), weight=1. - decay)
|
||||
|
||||
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
|
||||
ema_b.copy_(model_b.to(device=self.device))
|
||||
|
|
Loading…
Reference in New Issue