mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix the bug when input mask is not '0-1-tensor' in masked_conv (#2423)
parent
6e3827d02b
commit
72ca0f30dc
|
@ -58,6 +58,8 @@ class MaskedConv2dFunction(Function):
|
|||
if mask.size()[1:] != output.size()[2:]:
|
||||
raise ValueError(
|
||||
'The mask is inconsistent with the shape of output_conv.')
|
||||
mask = mask > 0
|
||||
mask = mask.type(output.dtype)
|
||||
output = output * mask
|
||||
return output
|
||||
|
||||
|
|
Loading…
Reference in New Issue