mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Use in-place operations for EMA
This commit is contained in:
parent
25ffac6880
commit
6ec5cd6a99
@ -117,10 +117,15 @@ class ModelEmaV2(nn.Module):
|
||||
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
||||
if self.device is not None:
|
||||
model_v = model_v.to(device=self.device)
|
||||
ema_v.copy_(update_fn(ema_v, model_v))
|
||||
update_fn(ema_v, model_v)
|
||||
|
||||
def update(self, model):
|
||||
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
|
||||
|
||||
def ema_update(e, m):
|
||||
if m.is_floating_point():
|
||||
e.mul_(self.decay).add_(m, alpha=1 - self.decay)
|
||||
|
||||
self._update(model, update_fn=ema_update)
|
||||
|
||||
def set(self, model):
|
||||
self._update(model, update_fn=lambda e, m: m)
|
||||
self._update(model, update_fn=lambda e, m: e.copy_(m))
|
||||
|
Loading…
x
Reference in New Issue
Block a user