mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
parent
f61295d944
commit
479624672a
@ -70,6 +70,9 @@ class DeformConv2dFunction(Function):
|
||||
ctx.deform_groups = deform_groups
|
||||
ctx.im2col_step = im2col_step
|
||||
|
||||
# until the code is modified for torch.cuda.amp.autocast,
|
||||
# we need to cast weight to avoid type mismatch in fp16 training
|
||||
weight = weight.type_as(input)
|
||||
ctx.save_for_backward(input, offset, weight)
|
||||
|
||||
output = input.new_empty(
|
||||
|
Loading…
x
Reference in New Issue
Block a user