From 4371ba5db60bee5e0544aab16c176ae72e85d0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=BC=80=E5=AE=87?= <1150249769@qq.com> Date: Tue, 18 Apr 2023 18:26:59 +0800 Subject: [PATCH] [Fix] Fix bugs when out_channels==1 (#2911) --- mmseg/models/segmentors/base.py | 1 + mmseg/models/segmentors/encoder_decoder.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 25487de5a..17a0bb2b3 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -187,6 +187,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta): if C > 1: i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) else: + i_seg_logits = i_seg_logits.sigmoid() i_seg_pred = (i_seg_logits > self.decode_head.threshold).to(i_seg_logits) data_samples[i].set_data({ diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 0a8db3ec7..370d0305f 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -260,10 +260,10 @@ class EncoderDecoder(BaseSegmentor): h_stride, w_stride = self.test_cfg.stride h_crop, w_crop = self.test_cfg.crop_size batch_size, _, h_img, w_img = inputs.size() - num_classes = self.num_classes + out_channels = self.out_channels h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 - preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img)) + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) for h_idx in range(h_grids): for w_idx in range(w_grids):