Speedup EMA
parent
541326eaf0
commit
cabdc251fe
|
@ -32,11 +32,14 @@ class ExponentialMovingAverage():
|
||||||
|
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def _update(self, model, update_fn):
|
def _update(self, model, update_fn):
|
||||||
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
for ema_v, model_v in zip(self.module.state_dict().values(),
|
||||||
ema_v.set_value(update_fn(ema_v, model_v))
|
model.state_dict().values()):
|
||||||
|
ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy()))
|
||||||
|
|
||||||
def update(self, model):
|
def update(self, model):
|
||||||
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
|
self._update(
|
||||||
|
model,
|
||||||
|
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
|
||||||
|
|
||||||
def set(self, model):
|
def set(self, model):
|
||||||
self._update(model, update_fn=lambda e, m: m)
|
self._update(model, update_fn=lambda e, m: m)
|
||||||
|
|
Loading…
Reference in New Issue