mirror of https://github.com/open-mmlab/mmcv.git
[Fix]: fix data type in fused-bias-leakyrelu for apex fp16 training (#981)
parent 9649a9ad22
commit 841a078e69
@@ -45,10 +45,9 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
         # The second order deviation, in fact, contains two parts, while the
         # the first part is zero. Thus, we direct consider the second part
         # which is similar with the first order deviation in implementation.
-        gradgrad_out = ext_module.fused_bias_leakyrelu(gradgrad_input,
-                                                       gradgrad_bias, out, 3,
-                                                       1, ctx.negative_slope,
-                                                       ctx.scale)
+        gradgrad_out = ext_module.fused_bias_leakyrelu(
+            gradgrad_input, gradgrad_bias.to(out.dtype), out, 3, 1,
+            ctx.negative_slope, ctx.scale)
 
         return gradgrad_out, None, None, None
 
@@ -139,7 +138,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
     if not input.is_cuda:
         return bias_leakyrelu_ref(input, bias, negative_slope, scale)
 
-    return FusedBiasLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+                                            negative_slope, scale)
 
 
 def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
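The change casts the bias to the dtype of the activations before the CUDA extension is invoked, which matters under mixed-precision training (e.g. apex O1, where activations run in fp16 while module parameters stay in fp32). Below is a minimal sketch of the fixed call path, assuming a CUDA build of mmcv with the fused_bias_leakyrelu extension; the tensor shapes are illustrative only.

# Minimal sketch (assumptions: CUDA device available, mmcv built with the
# fused_bias_leakyrelu extension; the fp16 setup mimics apex O1, where
# activations are half precision while module parameters remain fp32).
import torch
from mmcv.ops import fused_bias_leakyrelu

x = torch.randn(2, 8, 4, 4, device='cuda').half()          # fp16 activation
bias = torch.nn.Parameter(torch.zeros(8, device='cuda'))   # fp32 parameter

# Before this commit the extension received an fp16 input together with an
# fp32 bias and failed on the dtype mismatch; bias.to(input.dtype) inside
# fused_bias_leakyrelu now aligns the two before the CUDA kernel is called.
out = fused_bias_leakyrelu(x, bias)
assert out.dtype == torch.float16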