[Fix] inconsistent dtype ofseg_label in stdc decode (#1463)

This commit is contained in:
Miao Zheng 2022-04-10 14:18:16 +08:00 committed by GitHub
parent cba10b3f15
commit a09df2c39d

View File

@ -37,7 +37,7 @@ class STDCHead(FCNHead):
# parameters. However, it is a constant in original repo and other
# codebase because it would not be added into computation graph
# after threshold operation.
seg_label = seg_label.float()
seg_label = seg_label.to(self.laplacian_kernel)
boundary_targets = F.conv2d(
seg_label, self.laplacian_kernel, padding=1)
boundary_targets = boundary_targets.clamp(min=0)