mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix #2127 move to ema device
This commit is contained in:
parent
e25bbfceec
commit
24f6d4f7f8
@ -230,9 +230,9 @@ class ModelEmaV3(nn.Module):
|
||||
else:
|
||||
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
||||
if ema_v.is_floating_point():
|
||||
ema_v.lerp_(model_v, weight=1. - decay)
|
||||
ema_v.lerp_(model_v.to(device=self.device), weight=1. - decay)
|
||||
else:
|
||||
ema_v.copy_(model_v)
|
||||
ema_v.copy_(model_v.to(device=self.device))
|
||||
|
||||
def apply_update_no_buffers_(self, model, decay: float):
|
||||
# interpolate parameters, copy buffers
|
||||
@ -246,7 +246,7 @@ class ModelEmaV3(nn.Module):
|
||||
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
|
||||
else:
|
||||
for ema_p, model_p in zip(ema_params, model_params):
|
||||
ema_p.lerp_(model_p, weight=1. - decay)
|
||||
ema_p.lerp_(model_p.to(device=self.device), weight=1. - decay)
|
||||
|
||||
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
|
||||
ema_b.copy_(model_b.to(device=self.device))
|
||||
|
Loading…
x
Reference in New Issue
Block a user