[Fix] cast the type of mask to enable training with amp (#2220)

pull/2243/head
Yanhong Zeng 2022-08-24 20:28:32 +08:00 committed by Zaida Zhou
parent a08517790d
commit 505b5771eb
1 changed files with 1 additions and 0 deletions

View File

@ -68,6 +68,7 @@ class ModulatedDeformConv2dFunction(Function):
input = input.type_as(offset)
weight = weight.type_as(input)
bias = bias.type_as(input) # type: ignore
mask = mask.type_as(input)
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))