Prevent kernal_normalizer to change mask dtype (#1210)

pull/1214/head
ddonatien 2021-07-23 21:08:58 +08:00 committed by GitHub
parent 17fa6670eb
commit 0a375614ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -268,7 +268,7 @@ class CARAFEPack(nn.Module):
mask_channel = int(mask_c / float(self.up_kernel**2))
mask = mask.view(n, mask_channel, -1, h, w)
mask = F.softmax(mask, dim=2)
mask = F.softmax(mask, dim=2, dtype=mask.dtype)
mask = mask.view(n, mask_c, h, w).contiguous()
return mask