# Mirror of https://github.com/PaddlePaddle/PaddleClas.git
# (synced 2025-06-03 21:55:06 +08:00; 33 lines, 891 B, Python)
import functools
import paddle


def AMP_forward_decorator(func):
    """Wrap a model forward function so it honours the shared AMP level.

    On every call the wrapper reads ``AMPForwardDecorator.amp_level``.

    - If that level is falsy (e.g. the default ``None``), ``func`` is called
      directly.
    - Otherwise the call runs inside ``paddle.amp.auto_cast`` at that level,
      with ``flatten_contiguous_range`` and ``greater_than`` passed as
      ``custom_black_list``.

    Args:
        func: Callable invoked as ``func(model, *args, **kwargs)``.

    Returns:
        A wrapper with the same call signature as ``func``, carrying its
        metadata via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(model, *args, **kwargs):
        level = AMPForwardDecorator.amp_level
        # Guard clause: AMP disabled -> plain, full-precision call.
        if not level:
            return func(model, *args, **kwargs)
        # Ops listed in custom_black_list are excluded from low-precision
        # casting by auto_cast.
        blocked_ops = {"flatten_contiguous_range", "greater_than"}
        with paddle.amp.auto_cast(custom_black_list=blocked_ops, level=level):
            return func(model, *args, **kwargs)

    return wrapper
|
||
|
|
||
|
|
||
|
class AMPForwardDecorator(object):
    """Callable wrapper around a model forward function.

    The class attributes hold AMP configuration shared by every instance and
    by ``AMP_forward_decorator``. That function reads ``amp_level`` and skips
    autocast when it is falsy. ``amp_eval`` is not read in this class.
    NOTE(review): presumably it toggles AMP during evaluation; confirm
    against callers.

    Instances pass every call straight through to the wrapped function. They
    do not apply autocast themselves. An instance can be used as a plain
    callable, ``wrapper(model, *args)``, or as a method decorator on a
    model's ``forward``, in which case the model is bound as ``model_obj``.
    """

    # Shared AMP level; a falsy value (the default None) means "no autocast".
    amp_level = None
    # Shared AMP-at-evaluation setting; not consumed inside this class.
    amp_eval = None

    def __init__(self, forward_func):
        """Store ``forward_func`` and copy its metadata onto this wrapper.

        Args:
            forward_func: Callable invoked as
                ``forward_func(model_obj, *args, **kwargs)``.
        """
        self.forward_func = forward_func
        # Copy __name__, __doc__, __module__, __qualname__ etc. and set
        # __wrapped__. A bare ``@functools.wraps`` on ``__call__`` (with no
        # argument) would instead replace the method with a functools.partial
        # of update_wrapper, breaking every call.
        functools.update_wrapper(self, forward_func)

    def __get__(self, instance, owner=None):
        """Bind to ``instance`` when accessed as a method attribute.

        Returns the wrapper itself for class-level access, otherwise a
        callable with ``instance`` pre-filled as ``model_obj``.
        """
        if instance is None:
            return self
        return functools.partial(self, instance)

    def __call__(self, model_obj, *args, **kwargs):
        """Forward all arguments unchanged to the wrapped function."""
        return self.forward_func(model_obj, *args, **kwargs)