diff --git a/mmseg/models/decode_heads/stdc_head.py b/mmseg/models/decode_heads/stdc_head.py index 1e678ace1..1cf3732ce 100644 --- a/mmseg/models/decode_heads/stdc_head.py +++ b/mmseg/models/decode_heads/stdc_head.py @@ -37,7 +37,7 @@ class STDCHead(FCNHead): # parameters. However, it is a constant in original repo and other # codebase because it would not be added into computation graph # after threshold operation. - seg_label = seg_label.float() + seg_label = seg_label.to(self.laplacian_kernel) boundary_targets = F.conv2d( seg_label, self.laplacian_kernel, padding=1) boundary_targets = boundary_targets.clamp(min=0)