[Fix]: fix data type in fused-bias-leakyrelu for apex fp16 training (#981)

pull/989/head
Rui Xu 2021-04-24 19:10:13 +08:00 committed by GitHub
parent 9649a9ad22
commit 841a078e69
1 changed file with 5 additions and 5 deletions

@@ -45,10 +45,9 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
         # The second order derivative, in fact, contains two parts; the
         # first part is zero. Thus, we directly consider the second part,
         # which is similar to the first order derivative in implementation.
-        gradgrad_out = ext_module.fused_bias_leakyrelu(gradgrad_input,
-                                                       gradgrad_bias, out, 3,
-                                                       1, ctx.negative_slope,
-                                                       ctx.scale)
+        gradgrad_out = ext_module.fused_bias_leakyrelu(
+            gradgrad_input, gradgrad_bias.to(out.dtype), out, 3, 1,
+            ctx.negative_slope, ctx.scale)
         return gradgrad_out, None, None, None
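
For context on this hunk: apex O1 mixed precision keeps parameters and their gradients in fp32 while activations run in fp16, so gradgrad_bias can reach the backward pass as fp32 while out, saved from the forward pass, is fp16. Below is a minimal sketch of that mismatch and the cast; the shapes are assumptions, and CPU tensors stand in for the CUDA ones the kernel actually sees.

    import torch

    # Under apex O1, activations are fp16 but parameter gradients stay
    # fp32, so these two tensors can legitimately disagree in dtype.
    out = torch.randn(4, 8, dtype=torch.float16)  # saved fp16 activation
    gradgrad_bias = torch.randn(8)                # fp32 bias gradient

    # A fused CUDA kernel expects its tensor arguments to share one dtype;
    # casting the gradient to the activation's dtype restores agreement.
    gradgrad_bias = gradgrad_bias.to(out.dtype)
    assert gradgrad_bias.dtype == out.dtype == torch.float16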
@@ -139,7 +138,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
     if not input.is_cuda:
         return bias_leakyrelu_ref(input, bias, negative_slope, scale)
-    return FusedBiasLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+                                            negative_slope, scale)
 def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
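
A hedged usage sketch of the patched wrapper: under apex O1 the bias parameter stays fp32 while the input tensor is fp16, and the added bias.to(input.dtype) lets the call go through. The import path and shapes are assumptions for illustration, and a CUDA device is required for the fused path.

    import torch
    from mmcv.ops import fused_bias_leakyrelu  # import path assumed here

    x = torch.randn(2, 8, device='cuda', dtype=torch.float16)  # fp16 activation
    bias = torch.nn.Parameter(torch.zeros(8, device='cuda'))   # fp32 under apex O1

    # With this patch the wrapper casts bias to x.dtype before dispatching to
    # the CUDA op, so the mixed fp16/fp32 pair no longer raises a type error.
    y = fused_bias_leakyrelu(x, bias)
    assert y.dtype == torch.float16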