51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
|
from functools import partial
|
||
|
import contextlib
|
||
|
import paddle
|
||
|
|
||
|
|
||
|
class AutoCast:
    """Callable that hands out the context manager to wrap a forward pass in.

    Calling an instance returns ``paddle.amp.auto_cast(...)`` when mixed
    precision applies to the current phase, and ``contextlib.nullcontext()``
    otherwise.

    Args:
        use_amp: Enable automatic mixed precision. When False, every call
            returns a null context and ``paddle.amp`` is never accessed.
        amp_level: Forwarded to ``paddle.amp.auto_cast`` as ``level``.
        use_promote: Forwarded to ``paddle.amp.auto_cast`` as ``use_promote``
            when the installed Paddle accepts that keyword; otherwise it is
            ignored so older Paddle releases keep working.
        amp_eval: Whether calls made with ``is_eval=True`` should also run
            under autocast.
    """

    def __init__(self,
                 use_amp=False,
                 amp_level="O1",
                 use_promote=False,
                 amp_eval=False):
        self.use_amp = use_amp
        self.amp_eval = amp_eval

        if self.use_amp:
            cast_kwargs = {"level": amp_level}
            # Only pass use_promote when auto_cast declares it, so the
            # argument is honoured where possible without raising a
            # TypeError on Paddle versions that predate the keyword.
            if self._accepts_use_promote(paddle.amp.auto_cast):
                cast_kwargs["use_promote"] = use_promote
            self.cast_context = partial(paddle.amp.auto_cast, **cast_kwargs)

    @staticmethod
    def _accepts_use_promote(auto_cast):
        """Return True if ``auto_cast`` declares a ``use_promote`` parameter."""
        import inspect

        try:
            parameters = inspect.signature(auto_cast).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: stay on the conservative path.
            return False
        return "use_promote" in parameters

    def __call__(self, is_eval=False):
        """Return the context manager for one forward pass.

        Args:
            is_eval: True when the pass is an evaluation pass.

        Returns:
            The autocast context when AMP is enabled and the pass is either a
            training pass or an evaluation pass with ``amp_eval`` set;
            otherwise ``contextlib.nullcontext()``.
        """
        # Training always casts; evaluation casts only when amp_eval is set.
        if self.use_amp and (not is_eval or self.amp_eval):
            return self.cast_context()
        return contextlib.nullcontext()
|
||
|
|
||
|
|
||
|
class _NoOpScaler:
    """Stand-in scaler returned when AMP is disabled.

    Offers the ``scale`` / ``step`` / ``update`` / ``minimize`` methods that
    callers use on the real scaler, without doing any loss scaling.
    """

    def scale(self, loss):
        """Return ``loss`` unchanged."""
        return loss

    def step(self, optimizer):
        """Apply the optimizer update directly."""
        optimizer.step()

    def update(self):
        """Do nothing; there is no loss-scale state to refresh."""
        return

    def minimize(self, optimizer, loss):
        """Apply the optimizer update; ``loss`` is accepted but not used."""
        optimizer.step()


def build_scaler(use_amp=False, scale_loss=1.0,
                 use_dynamic_loss_scaling=False):
    """Build the gradient scaler for a training loop.

    Args:
        use_amp: When True, return a ``paddle.amp.GradScaler``; when False,
            return a no-op scaler with the same method names.
        scale_loss: Passed to ``GradScaler`` as ``init_loss_scaling``.
        use_dynamic_loss_scaling: Passed to ``GradScaler`` under the same
            name.

    Returns:
        A ``paddle.amp.GradScaler`` if ``use_amp`` is truthy, otherwise a
        ``_NoOpScaler`` instance.
    """
    if use_amp:
        return paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
    return _NoOpScaler()
|