[Fix] Fix the error of BCE loss when batch size is 1. (#1629)

pull/1642/head
Wencheng Wu 2022-05-31 18:07:29 +08:00 committed by GitHub
parent ef905223b3
commit 63fa98515b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -124,7 +124,7 @@ def binary_cross_entropy(pred,
assert label[label != ignore_index].max() <= 1, \
'For pred with shape [N, 1, H, W], its label must have at ' \
'most 2 classes'
pred = pred.squeeze()
pred = pred.squeeze(1)
if pred.dim() != label.dim():
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \