[Fix] Fix DCN fp16 (#1014)

* [Fix] Fix DCN fp16

* add comment
pull/1018/head
Yosuke Shinya 2021-05-11 14:16:23 +09:00 committed by GitHub
parent f61295d944
commit 479624672a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -70,6 +70,9 @@ class DeformConv2dFunction(Function):
ctx.deform_groups = deform_groups
ctx.im2col_step = im2col_step
# until the code is modified for torch.cuda.amp.autocast,
# we need to cast weight to avoid type mismatch in fp16 training
weight = weight.type_as(input)
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(