[Fix] Fix the error of BCE loss when batch size is 1. (#1629)
parent
ef905223b3
commit
63fa98515b
|
@ -124,7 +124,7 @@ def binary_cross_entropy(pred,
|
|||
assert label[label != ignore_index].max() <= 1, \
|
||||
'For pred with shape [N, 1, H, W], its label must have at ' \
|
||||
'most 2 classes'
|
||||
pred = pred.squeeze()
|
||||
pred = pred.squeeze(1)
|
||||
if pred.dim() != label.dim():
|
||||
assert (pred.dim() == 2 and label.dim() == 1) or (
|
||||
pred.dim() == 4 and label.dim() == 3), \
|
||||
|
|
Loading…
Reference in New Issue