166 lines
6.3 KiB
Python
166 lines
6.3 KiB
Python
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.fluid as fluid
|
|
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
|
|
from paddle.fluid.framework import Program, program_guard, name_scope, default_main_program
|
|
from paddle.fluid import unique_name, layers
|
|
|
|
|
|
class ExponentialMovingAverage(object):
    """Keep exponential moving averages (EMA) of model parameters.

    Constructing an instance works on the current default main program: for
    every parameter of its global block whose ``do_model_average`` attribute
    is not ``False`` it creates

    * a non-persistable temporary variable used to back the parameter up, and
    * a persistable, zero-initialised EMA variable.

    It also builds two standalone programs, exposed as public attributes:

    * ``apply_program`` copies each parameter into its temporary variable and
      then overwrites the parameter with its (optionally de-biased) EMA.
    * ``restore_program`` copies each temporary variable back into its
      parameter.

    Call ``update()`` to append the ops that refresh the averages, and use
    ``apply()`` / ``restore()`` with an executor to swap the averaged values
    in and out.

    Args:
        decay (float): Base decay rate of the moving average. Default 0.999.
        thres_steps (Variable|None): When given, the scheduled decay is
            ``(thres_steps + 1) / (thres_steps + 10)`` whenever that value is
            smaller than ``decay``, and ``decay`` otherwise. Presumably a
            step-counter Variable -- confirm against the caller.
        zero_debias (bool): When True, ``apply_program`` divides every EMA by
            ``1 - decay_var ** (global_steps + 1)`` before writing it into
            the parameter.
        name (str|None): Prefix prepended to the parameter name when naming
            the generated temporary and EMA variables. ``None`` means no
            prefix.
    """

    def __init__(self,
                 decay=0.999,
                 thres_steps=None,
                 zero_debias=False,
                 name=None):
        self._decay = decay
        self._thres_steps = thres_steps
        self._name = name if name is not None else ''
        self._decay_var = self._get_ema_decay()

        # List of (parameter, backup variable) pairs tracked by this EMA.
        self._params_tmps = self._create_param_backups()

        # Maps parameter name -> persistable EMA variable.
        self._ema_vars = {}
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                    [param, tmp]), name_scope('moving_average'):
                self._ema_vars[param.name] = self._create_ema_vars(param)

        self.apply_program = self._build_apply_program(zero_debias)
        self.restore_program = self._build_restore_program()

    def _create_param_backups(self):
        """Create one temporary backup variable per averaged parameter.

        Returns:
            list: ``(param, tmp)`` tuples, in the order the parameters are
            returned by the default main program's global block.
        """
        params_tmps = []
        for param in default_main_program().global_block().all_parameters():
            # Only an explicit ``False`` opts a parameter out; ``True`` and
            # ``None`` (unset) are both averaged.
            if param.do_model_average is False:
                continue
            tmp = param.block.create_var(
                name=unique_name.generate(".".join(
                    [self._name + param.name, 'ema_tmp'])),
                dtype=param.dtype,
                persistable=False,
                stop_gradient=True)
            params_tmps.append((param, tmp))
        return params_tmps

    def _build_apply_program(self, zero_debias):
        """Build the program that swaps the EMA values into the parameters.

        For each tracked parameter the program first assigns the parameter to
        its backup variable and then assigns the EMA (divided by
        ``1 - decay_pow`` when ``zero_debias`` is true) to the parameter.

        Args:
            zero_debias (bool): Whether to apply the bias correction.

        Returns:
            Program: The newly built program.
        """
        program = Program()
        block = program.global_block()
        with program_guard(main_program=program):
            decay_pow = self._get_decay_pow(block)
            # The correction factor is identical for every parameter, so it
            # is emitted once instead of once per loop iteration.
            debias = (1.0 - decay_pow) if zero_debias else None
            for param, tmp in self._params_tmps:
                param = block._clone_variable(param)
                tmp = block._clone_variable(tmp)
                ema = block._clone_variable(self._ema_vars[param.name])
                layers.assign(input=param, output=tmp)
                # bias correction
                if zero_debias:
                    ema = ema / debias
                layers.assign(input=ema, output=param)
        return program

    def _build_restore_program(self):
        """Build the program that copies every backup into its parameter.

        Returns:
            Program: The newly built program.
        """
        program = Program()
        block = program.global_block()
        with program_guard(main_program=program):
            for param, tmp in self._params_tmps:
                tmp = block._clone_variable(tmp)
                param = block._clone_variable(param)
                layers.assign(input=tmp, output=param)
        return program

    def _get_ema_decay(self):
        """Create the global variable that holds the scheduled decay rate.

        The variable is a persistable float32 tensor of shape ``[1]`` named
        ``"scheduled_ema_decay_rate"``, initialised to ``decay``. When
        ``thres_steps`` was supplied, a ``Switch`` is added that assigns
        ``(thres_steps + 1) / (thres_steps + 10)`` to it if that value is
        smaller than ``decay`` and ``decay`` otherwise. Everything is created
        under the default main program's ``_lr_schedule_guard()``.

        Returns:
            Variable: The decay-rate variable.
        """
        with default_main_program()._lr_schedule_guard():
            decay_var = layers.tensor.create_global_var(
                shape=[1],
                value=self._decay,
                dtype='float32',
                persistable=True,
                name="scheduled_ema_decay_rate")

            if self._thres_steps is not None:
                decay_t = (self._thres_steps + 1.0) / (
                    self._thres_steps + 10.0)
                with layers.control_flow.Switch() as switch:
                    with switch.case(decay_t < self._decay):
                        layers.tensor.assign(decay_t, decay_var)
                    with switch.default():
                        layers.tensor.assign(
                            np.array(
                                [self._decay], dtype=np.float32),
                            decay_var)
        return decay_var

    def _get_decay_pow(self, block):
        """Return ``decay_var ** (global_steps + 1)`` built in ``block``.

        ``global_steps`` is whatever
        ``layers.learning_rate_scheduler._decay_step_counter()`` returns in
        the program that is current when this method is called.

        Args:
            block (Block): Block into which the decay variable is cloned.

        Returns:
            Variable: Result of the ``elementwise_pow``.
        """
        global_steps = layers.learning_rate_scheduler._decay_step_counter()
        decay_var = block._clone_variable(self._decay_var)
        decay_pow_acc = layers.elementwise_pow(decay_var, global_steps + 1)
        return decay_pow_acc

    def _create_ema_vars(self, param):
        """Create the persistable EMA variable for ``param``.

        The variable has the parameter's shape and dtype, is initialised to
        0.0 and is named from ``name`` prefix + parameter name + ``'_ema'``
        via ``unique_name.generate``.

        Args:
            param (Parameter): Parameter to create the EMA variable for.

        Returns:
            Variable: The new EMA variable.
        """
        param_ema = layers.create_global_var(
            name=unique_name.generate(self._name + param.name + '_ema'),
            shape=param.shape,
            value=0.0,
            dtype=param.dtype,
            persistable=True)

        return param_ema

    def update(self):
        """
        Update Exponential Moving Average. Should only call this method in
        train program.

        For each tracked parameter this appends ops computing
        ``ema = ema * decay_var + param * (1 - decay_var)``. A parameter
        that has a companion EMA entry named ``<param.name>.master`` is not
        updated directly; instead a ``cast`` op copies the companion's EMA
        into the parameter's EMA after all other updates were appended.
        """
        param_master_emas = []
        for param, tmp in self._params_tmps:
            with param.block.program._optimized_guard(
                    [param, tmp]), name_scope('moving_average'):
                param_ema = self._ema_vars[param.name]
                master_name = param.name + '.master'
                if master_name in self._ema_vars:
                    # Defer: this EMA is derived from the master's EMA below.
                    master_ema = self._ema_vars[master_name]
                    param_master_emas.append([param_ema, master_ema])
                else:
                    ema_t = param_ema * self._decay_var + param * (
                        1 - self._decay_var)
                    layers.assign(input=ema_t, output=param_ema)

        # for fp16 params (presumably ``.master`` is the full-precision copy
        # of a half-precision parameter -- confirm against the AMP setup)
        for param_ema, master_ema in param_master_emas:
            default_main_program().global_block().append_op(
                type="cast",
                inputs={"X": master_ema},
                outputs={"Out": param_ema},
                attrs={
                    "in_dtype": master_ema.dtype,
                    "out_dtype": param_ema.dtype
                })

    @signature_safe_contextmanager
    def apply(self, executor, need_restore=True):
        """
        Apply moving average to parameters for evaluation.

        Used as a context manager: ``apply_program`` is run on entry and,
        when ``need_restore`` is true, ``restore()`` is called on exit, also
        when the body of the ``with`` statement raises.

        Args:
            executor (Executor): The Executor to execute applying.
            need_restore (bool): Whether to restore parameters after applying.
        """
        executor.run(self.apply_program)
        try:
            yield
        finally:
            if need_restore:
                self.restore(executor)

    def restore(self, executor):
        """Restore parameters.

        Runs ``restore_program``, which copies the values backed up by
        ``apply_program`` into the parameters.

        Args:
            executor (Executor): The Executor to execute restoring.
        """
        executor.run(self.restore_program)
|