[Fix] Fix the bug in binary_cross_entropy (#1527)
* [Fix] Fix the bug in binary_cross_entropy Fix the bug in binary_cross_entropy 'label.max() <= 1' should mask out ignore_index, since the ignore_index often set as 255. * [Fix] Fix the bug in binary_cross_entropy, add comments As the ignore_index often set as 255, so the binary class label check should mask out ignore_index. Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> * [Fix] Fix the bug in binary_cross_entropy As the ignore_index often set as 255, so the binary class label check should mask out ignore_index. Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: MeowZheng <meowzheng@outlook.com>pull/1801/head
parent
d0955901b6
commit
578d4d0c42
|
@ -118,7 +118,10 @@ def binary_cross_entropy(pred,
|
|||
if pred.size(1) == 1:
|
||||
# For binary class segmentation, the shape of pred is
|
||||
# [N, 1, H, W] and that of label is [N, H, W].
|
||||
assert label.max() <= 1, \
|
||||
# As the ignore_index often set as 255, so the
|
||||
# binary class label check should mask out
|
||||
# ignore_index
|
||||
assert label[label != ignore_index].max() <= 1, \
|
||||
'For pred with shape [N, 1, H, W], its label must have at ' \
|
||||
'most 2 classes'
|
||||
pred = pred.squeeze()
|
||||
|
|
Loading…
Reference in New Issue