mirror of https://github.com/open-mmlab/mmcv.git
parent
f61295d944
commit
479624672a
|
@ -70,6 +70,9 @@ class DeformConv2dFunction(Function):
|
|||
ctx.deform_groups = deform_groups
|
||||
ctx.im2col_step = im2col_step
|
||||
|
||||
# until the code is modified for torch.cuda.amp.autocast,
|
||||
# we need to cast weight to avoid type mismatch in fp16 training
|
||||
weight = weight.type_as(input)
|
||||
ctx.save_for_backward(input, offset, weight)
|
||||
|
||||
output = input.new_empty(
|
||||
|
|
Loading…
Reference in New Issue