[Fix] Fix binary segmentation when num_classes==1 (#2016)

* fix binary

* add ut

* fix ut

* restore metric computation

* remove metric ut update

* set out_channels by num_classes

* replace num_classes in encoder_decoder

* update props setting and fix ut

* update ut

* minor change

* update warning
This commit is contained in:
谢昕辰 2022-09-08 14:43:21 +08:00 committed by GitHub
parent d8ea8f7460
commit c1c942e8fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 7 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
import torch
@ -18,6 +19,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
out_channels (int): Output channels of conv_seg.
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
@ -56,6 +60,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
channels,
*,
num_classes,
out_channels=None,
threshold=None,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
@ -74,7 +80,6 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
super(BaseDecodeHead, self).__init__(init_cfg)
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
@ -84,6 +89,30 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
self.ignore_index = ignore_index
self.align_corners = align_corners
if out_channels is None:
if num_classes == 2:
warnings.warn('For binary segmentation, we suggest using'
'`out_channels = 1` to define the output'
'channels of segmentor, and use `threshold`'
'to convert seg_logist into a prediction'
'applying a threshold')
out_channels = num_classes
if out_channels != num_classes and out_channels != 1:
raise ValueError(
'out_channels should be equal to num_classes,'
'except binary segmentation set out_channels == 1 and'
f'num_classes == 2, but got out_channels={out_channels}'
f'and num_classes={num_classes}')
if out_channels == 1 and threshold is None:
threshold = 0.3
warnings.warn('threshold is not defined for binary, and defaults'
'to 0.3')
self.num_classes = num_classes
self.out_channels = out_channels
self.threshold = threshold
if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
@ -99,7 +128,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:

View File

@ -47,6 +47,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.decode_head.append(builder.build_head(decode_head[i]))
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes
self.out_channels = self.decode_head[-1].out_channels
def encode_decode(self, img, img_metas):
"""Encode images with backbone and decode into a semantic segmentation

View File

@ -49,6 +49,7 @@ class EncoderDecoder(BaseSegmentor):
self.decode_head = builder.build_head(decode_head)
self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes
self.out_channels = self.decode_head.out_channels
def _init_auxiliary_head(self, auxiliary_head):
"""Initialize ``auxiliary_head``"""
@ -162,10 +163,10 @@ class EncoderDecoder(BaseSegmentor):
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = img.size()
num_classes = self.num_classes
out_channels = self.out_channels
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
@ -245,7 +246,10 @@ class EncoderDecoder(BaseSegmentor):
seg_logit = self.slide_inference(img, img_meta, rescale)
else:
seg_logit = self.whole_inference(img, img_meta, rescale)
output = F.softmax(seg_logit, dim=1)
if self.out_channels == 1:
output = F.sigmoid(seg_logit)
else:
output = F.softmax(seg_logit, dim=1)
flip = img_meta[0]['flip']
if flip:
flip_direction = img_meta[0]['flip_direction']
@ -260,7 +264,11 @@ class EncoderDecoder(BaseSegmentor):
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
seg_logit = self.inference(img, img_meta, rescale)
seg_pred = seg_logit.argmax(dim=1)
if self.out_channels == 1:
seg_pred = (seg_logit >
self.decode_head.threshold).to(seg_logit).squeeze(1)
else:
seg_pred = seg_logit.argmax(dim=1)
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
seg_pred = seg_pred.unsqueeze(0)
@ -283,7 +291,11 @@ class EncoderDecoder(BaseSegmentor):
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
seg_logit += cur_seg_logit
seg_logit /= len(imgs)
seg_pred = seg_logit.argmax(dim=1)
if self.out_channels == 1:
seg_pred = (seg_logit >
self.decode_head.threshold).to(seg_logit).squeeze(1)
else:
seg_pred = seg_logit.argmax(dim=1)
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim
seg_pred = list(seg_pred)

View File

@ -43,6 +43,18 @@ def test_decode_head():
in_index=[-1],
input_transform='resize_concat')
with pytest.raises(ValueError):
# out_channels should be equal to num_classes
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)
# test out_channels
head = BaseDecodeHead(32, 16, num_classes=2)
assert head.out_channels == 2
# test out_channels == 1 and num_classes == 2
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
assert head.out_channels == 1 and head.num_classes == 2
# test default dropout
head = BaseDecodeHead(32, 16, num_classes=19)
assert hasattr(head, 'dropout') and head.dropout.p == 0.1

View File

@ -18,6 +18,12 @@ def test_encoder_decoder():
segmentor = build_segmentor(cfg)
_segmentor_forward_train_test(segmentor)
# test out_channels == 1
segmentor.out_channels = 1
segmentor.decode_head.out_channels = 1
segmentor.decode_head.threshold = 0.3
_segmentor_forward_train_test(segmentor)
# test slide mode
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
segmentor = build_segmentor(cfg)