mirror of https://github.com/open-mmlab/mmcv.git
Prevent kernal_normalizer to change mask dtype (#1210)
parent
17fa6670eb
commit
0a375614ca
|
@ -268,7 +268,7 @@ class CARAFEPack(nn.Module):
|
|||
mask_channel = int(mask_c / float(self.up_kernel**2))
|
||||
mask = mask.view(n, mask_channel, -1, h, w)
|
||||
|
||||
mask = F.softmax(mask, dim=2)
|
||||
mask = F.softmax(mask, dim=2, dtype=mask.dtype)
|
||||
mask = mask.view(n, mask_c, h, w).contiguous()
|
||||
|
||||
return mask
|
||||
|
|
Loading…
Reference in New Issue