fix ema: set_value() -> paddle.assign()

pull/2820/head
gaotingquan 2023-05-25 02:37:47 +00:00 committed by cuicheng01
parent 2823e48be5
commit f67cfe2c2a
1 changed files with 7 additions and 4 deletions

View File

@ -32,11 +32,14 @@ class ExponentialMovingAverage():
@paddle.no_grad()
def _update(self, model, update_fn):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.set_value(update_fn(ema_v, model_v))
for ema_v, model_v in zip(self.module.state_dict().values(),
model.state_dict().values()):
paddle.assign(update_fn(ema_v, model_v), ema_v)
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
self._update(
model,
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
self._update(model, update_fn=lambda e, m: m)