diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py
index 8ffc683bf..f6b05dd3e 100644
--- a/mmseg/models/decode_heads/decode_head.py
+++ b/mmseg/models/decode_heads/decode_head.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
 from abc import ABCMeta, abstractmethod
 
 import torch
@@ -18,6 +19,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         in_channels (int|Sequence[int]): Input channels.
         channels (int): Channels after modules, before conv_seg.
         num_classes (int): Number of classes.
+        out_channels (int|None): Output channels of conv_seg. Default: None.
+        threshold (float): Threshold for binary segmentation in the case of
+            `out_channels==1`. Default: None.
         dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
         conv_cfg (dict|None): Config of conv layers. Default: None.
         norm_cfg (dict|None): Config of norm layers. Default: None.
@@ -56,6 +60,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
                  channels,
                  *,
                  num_classes,
+                 out_channels=None,
+                 threshold=None,
                  dropout_ratio=0.1,
                  conv_cfg=None,
                  norm_cfg=None,
@@ -74,7 +80,6 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         super(BaseDecodeHead, self).__init__(init_cfg)
         self._init_inputs(in_channels, in_index, input_transform)
         self.channels = channels
-        self.num_classes = num_classes
         self.dropout_ratio = dropout_ratio
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
@@ -84,6 +89,30 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         self.ignore_index = ignore_index
         self.align_corners = align_corners
 
+        if out_channels is None:
+            if num_classes == 2:
+                warnings.warn('For binary segmentation, we suggest using '
+                              '`out_channels = 1` to define the output '
+                              'channels of segmentor, and use `threshold` '
+                              'to convert seg_logit into a prediction '
+                              'applying a threshold')
+            out_channels = num_classes
+
+        if out_channels != num_classes and out_channels != 1:
+            raise ValueError(
+                'out_channels should be equal to num_classes, '
+                'except binary segmentation set out_channels == 1 and '
+                f'num_classes == 2, but got out_channels={out_channels} '
+                f'and num_classes={num_classes}')
+
+        if out_channels == 1 and threshold is None:
+            threshold = 0.3
+            warnings.warn('threshold is not defined for binary, and defaults '
+                          'to 0.3')
+        self.num_classes = num_classes
+        self.out_channels = out_channels
+        self.threshold = threshold
+
         if isinstance(loss_decode, dict):
             self.loss_decode = build_loss(loss_decode)
         elif isinstance(loss_decode, (list, tuple)):
@@ -99,7 +128,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         else:
             self.sampler = None
 
-        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
         if dropout_ratio > 0:
             self.dropout = nn.Dropout2d(dropout_ratio)
         else:
diff --git a/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmseg/models/segmentors/cascade_encoder_decoder.py
index 1913a22e2..e9a9127a6 100644
--- a/mmseg/models/segmentors/cascade_encoder_decoder.py
+++ b/mmseg/models/segmentors/cascade_encoder_decoder.py
@@ -47,6 +47,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
             self.decode_head.append(builder.build_head(decode_head[i]))
         self.align_corners = self.decode_head[-1].align_corners
         self.num_classes = self.decode_head[-1].num_classes
+        self.out_channels = self.decode_head[-1].out_channels
 
     def encode_decode(self, img, img_metas):
         """Encode images with backbone and decode into a semantic segmentation
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index d94a3739e..678ae2b76 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -49,6 +49,7 @@ class EncoderDecoder(BaseSegmentor):
         self.decode_head = builder.build_head(decode_head)
         self.align_corners = self.decode_head.align_corners
         self.num_classes = self.decode_head.num_classes
+        self.out_channels = self.decode_head.out_channels
 
     def _init_auxiliary_head(self, auxiliary_head):
         """Initialize ``auxiliary_head``"""
@@ -162,10 +163,10 @@ class EncoderDecoder(BaseSegmentor):
         h_stride, w_stride = self.test_cfg.stride
         h_crop, w_crop = self.test_cfg.crop_size
         batch_size, _, h_img, w_img = img.size()
-        num_classes = self.num_classes
+        out_channels = self.out_channels
         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
-        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+        preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
         count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
         for h_idx in range(h_grids):
             for w_idx in range(w_grids):
@@ -245,7 +246,10 @@ class EncoderDecoder(BaseSegmentor):
             seg_logit = self.slide_inference(img, img_meta, rescale)
         else:
             seg_logit = self.whole_inference(img, img_meta, rescale)
-        output = F.softmax(seg_logit, dim=1)
+        if self.out_channels == 1:
+            output = torch.sigmoid(seg_logit)
+        else:
+            output = F.softmax(seg_logit, dim=1)
         flip = img_meta[0]['flip']
         if flip:
             flip_direction = img_meta[0]['flip_direction']
@@ -260,7 +264,11 @@ class EncoderDecoder(BaseSegmentor):
     def simple_test(self, img, img_meta, rescale=True):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
-        seg_pred = seg_logit.argmax(dim=1)
+        if self.out_channels == 1:
+            seg_pred = (seg_logit >
+                        self.decode_head.threshold).to(seg_logit).squeeze(1)
+        else:
+            seg_pred = seg_logit.argmax(dim=1)
         if torch.onnx.is_in_onnx_export():
             # our inference backend only support 4D output
             seg_pred = seg_pred.unsqueeze(0)
@@ -283,7 +291,11 @@ class EncoderDecoder(BaseSegmentor):
             cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
             seg_logit += cur_seg_logit
         seg_logit /= len(imgs)
-        seg_pred = seg_logit.argmax(dim=1)
+        if self.out_channels == 1:
+            seg_pred = (seg_logit >
+                        self.decode_head.threshold).to(seg_logit).squeeze(1)
+        else:
+            seg_pred = seg_logit.argmax(dim=1)
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim
         seg_pred = list(seg_pred)
diff --git a/tests/test_models/test_heads/test_decode_head.py b/tests/test_models/test_heads/test_decode_head.py
index cb9ab9718..87cadbcf8 100644
--- a/tests/test_models/test_heads/test_decode_head.py
+++ b/tests/test_models/test_heads/test_decode_head.py
@@ -43,6 +43,18 @@ def test_decode_head():
             in_index=[-1],
             input_transform='resize_concat')
 
+    with pytest.raises(ValueError):
+        # out_channels should be equal to num_classes
+        BaseDecodeHead(32, 16, num_classes=19, out_channels=18)
+
+    # test out_channels
+    head = BaseDecodeHead(32, 16, num_classes=2)
+    assert head.out_channels == 2
+
+    # test out_channels == 1 and num_classes == 2
+    head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
+    assert head.out_channels == 1 and head.num_classes == 2
+
     # test default dropout
     head = BaseDecodeHead(32, 16, num_classes=19)
     assert hasattr(head, 'dropout') and head.dropout.p == 0.1
diff --git a/tests/test_models/test_segmentors/test_encoder_decoder.py b/tests/test_models/test_segmentors/test_encoder_decoder.py
index 4ed143727..2739b5867 100644
--- a/tests/test_models/test_segmentors/test_encoder_decoder.py
+++ b/tests/test_models/test_segmentors/test_encoder_decoder.py
@@ -18,6 +18,12 @@ def test_encoder_decoder():
     segmentor = build_segmentor(cfg)
     _segmentor_forward_train_test(segmentor)
 
+    # test out_channels == 1
+    segmentor.out_channels = 1
+    segmentor.decode_head.out_channels = 1
+    segmentor.decode_head.threshold = 0.3
+    _segmentor_forward_train_test(segmentor)
+
     # test slide mode
     cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
     segmentor = build_segmentor(cfg)