Fix #2127 move to ema device

pull/2162/head
Ross Wightman 2024-04-10 21:29:09 -07:00
parent e25bbfceec
commit 24f6d4f7f8
1 changed files with 3 additions and 3 deletions

View File

@ -230,9 +230,9 @@ class ModelEmaV3(nn.Module):
else:
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if ema_v.is_floating_point():
ema_v.lerp_(model_v, weight=1. - decay)
ema_v.lerp_(model_v.to(device=self.device), weight=1. - decay)
else:
ema_v.copy_(model_v)
ema_v.copy_(model_v.to(device=self.device))
def apply_update_no_buffers_(self, model, decay: float):
# interpolate parameters, copy buffers
@ -246,7 +246,7 @@ class ModelEmaV3(nn.Module):
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
else:
for ema_p, model_p in zip(ema_params, model_params):
ema_p.lerp_(model_p, weight=1. - decay)
ema_p.lerp_(model_p.to(device=self.device), weight=1. - decay)
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
ema_b.copy_(model_b.to(device=self.device))