diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 25487de5a..17a0bb2b3 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -187,6 +187,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta): if C > 1: i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) else: + i_seg_logits = i_seg_logits.sigmoid() i_seg_pred = (i_seg_logits > self.decode_head.threshold).to(i_seg_logits) data_samples[i].set_data({ diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 0a8db3ec7..370d0305f 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -260,10 +260,10 @@ class EncoderDecoder(BaseSegmentor): h_stride, w_stride = self.test_cfg.stride h_crop, w_crop = self.test_cfg.crop_size batch_size, _, h_img, w_img = inputs.size() - num_classes = self.num_classes + out_channels = self.out_channels h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 - preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img)) + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) for h_idx in range(h_grids): for w_idx in range(w_grids):