mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix binary segmentation when num_classes==1
(#2016)
* fix binary * add ut * fix ut * restore metric computation * remove metric ut update * set out_channels by num_classes * replace num_classes in encoder_decoder * update props setting and fix ut * update ut * minor change * update warning
This commit is contained in:
parent
d8ea8f7460
commit
c1c942e8fc
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import torch
|
||||
@ -18,6 +19,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
out_channels (int): Output channels of conv_seg.
|
||||
threshold (float): Threshold for binary segmentation in the case of
|
||||
`num_classes==1`. Default: None.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
@ -56,6 +60,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
out_channels=None,
|
||||
threshold=None,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
@ -74,7 +80,6 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
super(BaseDecodeHead, self).__init__(init_cfg)
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.num_classes = num_classes
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
@ -84,6 +89,30 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
|
||||
if out_channels is None:
|
||||
if num_classes == 2:
|
||||
warnings.warn('For binary segmentation, we suggest using'
|
||||
'`out_channels = 1` to define the output'
|
||||
'channels of segmentor, and use `threshold`'
|
||||
'to convert seg_logist into a prediction'
|
||||
'applying a threshold')
|
||||
out_channels = num_classes
|
||||
|
||||
if out_channels != num_classes and out_channels != 1:
|
||||
raise ValueError(
|
||||
'out_channels should be equal to num_classes,'
|
||||
'except binary segmentation set out_channels == 1 and'
|
||||
f'num_classes == 2, but got out_channels={out_channels}'
|
||||
f'and num_classes={num_classes}')
|
||||
|
||||
if out_channels == 1 and threshold is None:
|
||||
threshold = 0.3
|
||||
warnings.warn('threshold is not defined for binary, and defaults'
|
||||
'to 0.3')
|
||||
self.num_classes = num_classes
|
||||
self.out_channels = out_channels
|
||||
self.threshold = threshold
|
||||
|
||||
if isinstance(loss_decode, dict):
|
||||
self.loss_decode = build_loss(loss_decode)
|
||||
elif isinstance(loss_decode, (list, tuple)):
|
||||
@ -99,7 +128,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
||||
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
|
@ -47,6 +47,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
||||
self.decode_head.append(builder.build_head(decode_head[i]))
|
||||
self.align_corners = self.decode_head[-1].align_corners
|
||||
self.num_classes = self.decode_head[-1].num_classes
|
||||
self.out_channels = self.decode_head[-1].out_channels
|
||||
|
||||
def encode_decode(self, img, img_metas):
|
||||
"""Encode images with backbone and decode into a semantic segmentation
|
||||
|
@ -49,6 +49,7 @@ class EncoderDecoder(BaseSegmentor):
|
||||
self.decode_head = builder.build_head(decode_head)
|
||||
self.align_corners = self.decode_head.align_corners
|
||||
self.num_classes = self.decode_head.num_classes
|
||||
self.out_channels = self.decode_head.out_channels
|
||||
|
||||
def _init_auxiliary_head(self, auxiliary_head):
|
||||
"""Initialize ``auxiliary_head``"""
|
||||
@ -162,10 +163,10 @@ class EncoderDecoder(BaseSegmentor):
|
||||
h_stride, w_stride = self.test_cfg.stride
|
||||
h_crop, w_crop = self.test_cfg.crop_size
|
||||
batch_size, _, h_img, w_img = img.size()
|
||||
num_classes = self.num_classes
|
||||
out_channels = self.out_channels
|
||||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
||||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
||||
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
|
||||
preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
|
||||
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
||||
for h_idx in range(h_grids):
|
||||
for w_idx in range(w_grids):
|
||||
@ -245,7 +246,10 @@ class EncoderDecoder(BaseSegmentor):
|
||||
seg_logit = self.slide_inference(img, img_meta, rescale)
|
||||
else:
|
||||
seg_logit = self.whole_inference(img, img_meta, rescale)
|
||||
output = F.softmax(seg_logit, dim=1)
|
||||
if self.out_channels == 1:
|
||||
output = F.sigmoid(seg_logit)
|
||||
else:
|
||||
output = F.softmax(seg_logit, dim=1)
|
||||
flip = img_meta[0]['flip']
|
||||
if flip:
|
||||
flip_direction = img_meta[0]['flip_direction']
|
||||
@ -260,7 +264,11 @@ class EncoderDecoder(BaseSegmentor):
|
||||
def simple_test(self, img, img_meta, rescale=True):
|
||||
"""Simple test with single image."""
|
||||
seg_logit = self.inference(img, img_meta, rescale)
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
if self.out_channels == 1:
|
||||
seg_pred = (seg_logit >
|
||||
self.decode_head.threshold).to(seg_logit).squeeze(1)
|
||||
else:
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
# our inference backend only support 4D output
|
||||
seg_pred = seg_pred.unsqueeze(0)
|
||||
@ -283,7 +291,11 @@ class EncoderDecoder(BaseSegmentor):
|
||||
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
|
||||
seg_logit += cur_seg_logit
|
||||
seg_logit /= len(imgs)
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
if self.out_channels == 1:
|
||||
seg_pred = (seg_logit >
|
||||
self.decode_head.threshold).to(seg_logit).squeeze(1)
|
||||
else:
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
seg_pred = seg_pred.cpu().numpy()
|
||||
# unravel batch dim
|
||||
seg_pred = list(seg_pred)
|
||||
|
@ -43,6 +43,18 @@ def test_decode_head():
|
||||
in_index=[-1],
|
||||
input_transform='resize_concat')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# out_channels should be equal to num_classes
|
||||
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)
|
||||
|
||||
# test out_channels
|
||||
head = BaseDecodeHead(32, 16, num_classes=2)
|
||||
assert head.out_channels == 2
|
||||
|
||||
# test out_channels == 1 and num_classes == 2
|
||||
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
|
||||
assert head.out_channels == 1 and head.num_classes == 2
|
||||
|
||||
# test default dropout
|
||||
head = BaseDecodeHead(32, 16, num_classes=19)
|
||||
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
|
||||
|
@ -18,6 +18,12 @@ def test_encoder_decoder():
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test out_channels == 1
|
||||
segmentor.out_channels = 1
|
||||
segmentor.decode_head.out_channels = 1
|
||||
segmentor.decode_head.threshold = 0.3
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
# test slide mode
|
||||
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
|
||||
segmentor = build_segmentor(cfg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user