diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index e00a98b99..275b503a4 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -58,6 +58,8 @@ class MaskedConv2dFunction(Function): if mask.size()[1:] != output.size()[2:]: raise ValueError( 'The mask is inconsistent with the shape of output_conv.') + mask = mask > 0 + mask = mask.type(output.dtype) output = output * mask return output