parent
03795249c1
commit
fd7ef078fe
|
@ -1,32 +0,0 @@
|
|||
import functools
|
||||
import paddle
|
||||
|
||||
|
||||
def AMP_forward_decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(model, *args, **kwargs):
|
||||
if AMPForwardDecorator.amp_level:
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=AMPForwardDecorator.amp_level):
|
||||
return func(model, *args, **kwargs)
|
||||
else:
|
||||
return func(model, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AMPForwardDecorator(object):
|
||||
amp_level = None
|
||||
amp_eval = None
|
||||
|
||||
def __init__(self, forward_func):
|
||||
self.forward_func = forward_func
|
||||
|
||||
@functools.wraps
|
||||
def __call__(self, model_obj, *args, **kwargs):
|
||||
# print(type(self))
|
||||
# print(type(model_obj))
|
||||
return self.forward_func(model_obj, *args, **kwargs)
|
Loading…
Reference in New Issue