# PaddleClas/ppcls/utils/amp.py
# (33 lines, 891 B, Python)

import functools
import paddle
def AMP_forward_decorator(func):
    """Wrap a forward function so it runs under Paddle automatic mixed precision.

    The AMP level is looked up on ``AMPForwardDecorator.amp_level`` each time
    the wrapper is called, not when it is decorated. Changing that attribute
    later therefore affects functions that are already decorated. A falsy
    level means the wrapped function is called directly, with no autocast
    context.

    Args:
        func: Forward callable whose first positional argument is the model.

    Returns:
        A wrapper with the same call signature as ``func``; its metadata is
        copied from ``func`` via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(model, *args, **kwargs):
        level = AMPForwardDecorator.amp_level
        if not level:
            # AMP disabled: plain call.
            return func(model, *args, **kwargs)
        # Ops listed here are passed as Paddle's custom_black_list, i.e.
        # excluded from low-precision casting inside auto_cast.
        black_list = {"flatten_contiguous_range", "greater_than"}
        with paddle.amp.auto_cast(custom_black_list=black_list, level=level):
            return func(model, *args, **kwargs)

    return wrapper
class AMPForwardDecorator(object):
    """Callable wrapper around a model forward function.

    Holds AMP settings as class attributes and forwards every call to the
    wrapped function unchanged. Metadata (``__name__``, ``__doc__``,
    ``__module__``, ``__qualname__``, ``__dict__``, ``__wrapped__``) is copied
    from the wrapped function onto the instance.

    Usage:
        ``AMPForwardDecorator(f)(model, *args)`` calls ``f(model, *args)``.
        When assigned in a class body, accessing it through an instance binds
        that instance as ``model_obj``, so ``obj.forward(x)`` calls
        ``f(obj, x)``.

    NOTE(review): ``__call__`` does not enter ``paddle.amp.auto_cast``; it
    only delegates. Confirm whether autocasting is meant to happen here too.
    """

    # Read by AMP_forward_decorator to choose the auto_cast level; a falsy
    # value disables autocasting there.
    amp_level = None
    # Not read by the code in this module; presumably toggles AMP during
    # evaluation. TODO: confirm against callers.
    amp_eval = None

    def __init__(self, forward_func):
        # Copy the wrapped function's metadata onto this instance.
        # functools.wraps must NOT be applied to __call__ itself: a bare
        # ``@functools.wraps`` turns __call__ into a functools.partial of
        # update_wrapper, which breaks calling the instance.
        functools.update_wrapper(self, forward_func)
        # Assigned after update_wrapper so a ``forward_func`` entry in the
        # wrapped function's __dict__ cannot shadow the real target.
        self.forward_func = forward_func

    def __get__(self, obj, objtype=None):
        """Bind ``obj`` as ``model_obj`` when accessed through an instance."""
        if obj is None:
            # Accessed on the class: return the undecorated wrapper.
            return self
        return functools.partial(self, obj)

    def __call__(self, model_obj, *args, **kwargs):
        """Call the wrapped function with ``model_obj`` as its first argument."""
        return self.forward_func(model_obj, *args, **kwargs)