From 63fa98515bfec26c9261f6baa0adcf8e7fccfe2f Mon Sep 17 00:00:00 2001 From: Wencheng Wu <41542251+274869388@users.noreply.github.com> Date: Tue, 31 May 2022 18:07:29 +0800 Subject: [PATCH] [Fix] Fix the error of BCE loss when batch size is 1. (#1629) --- mmseg/models/losses/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index 623fd58db..fe7b4a262 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -124,7 +124,7 @@ def binary_cross_entropy(pred, assert label[label != ignore_index].max() <= 1, \ 'For pred with shape [N, 1, H, W], its label must have at ' \ 'most 2 classes' - pred = pred.squeeze() + pred = pred.squeeze(1) if pred.dim() != label.dim(): assert (pred.dim() == 2 and label.dim() == 1) or ( pred.dim() == 4 and label.dim() == 3), \