mirror of https://github.com/open-mmlab/mmcv.git
[Fix] cast the type of mask to enable training with amp (#2220)
parent
a08517790d
commit
505b5771eb
|
@ -68,6 +68,7 @@ class ModulatedDeformConv2dFunction(Function):
|
|||
input = input.type_as(offset)
|
||||
weight = weight.type_as(input)
|
||||
bias = bias.type_as(input) # type: ignore
|
||||
mask = mask.type_as(input)
|
||||
ctx.save_for_backward(input, offset, mask, weight, bias)
|
||||
output = input.new_empty(
|
||||
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
|
||||
|
|
Loading…
Reference in New Issue