mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix binary segmentation when num_classes==1
(#2016)
* fix binary * add ut * fix ut * restore metric computation * remove metric ut update * set out_channels by num_classes * replace num_classes in encoder_decoder * update props setting and fix ut * update ut * minor change * update warning
This commit is contained in:
parent
d8ea8f7460
commit
c1c942e8fc
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import warnings
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -18,6 +19,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||||||
in_channels (int|Sequence[int]): Input channels.
|
in_channels (int|Sequence[int]): Input channels.
|
||||||
channels (int): Channels after modules, before conv_seg.
|
channels (int): Channels after modules, before conv_seg.
|
||||||
num_classes (int): Number of classes.
|
num_classes (int): Number of classes.
|
||||||
|
out_channels (int): Output channels of conv_seg.
|
||||||
|
threshold (float): Threshold for binary segmentation in the case of
|
||||||
|
`num_classes==1`. Default: None.
|
||||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||||
@ -56,6 +60,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||||||
channels,
|
channels,
|
||||||
*,
|
*,
|
||||||
num_classes,
|
num_classes,
|
||||||
|
out_channels=None,
|
||||||
|
threshold=None,
|
||||||
dropout_ratio=0.1,
|
dropout_ratio=0.1,
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
@ -74,7 +80,6 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||||||
super(BaseDecodeHead, self).__init__(init_cfg)
|
super(BaseDecodeHead, self).__init__(init_cfg)
|
||||||
self._init_inputs(in_channels, in_index, input_transform)
|
self._init_inputs(in_channels, in_index, input_transform)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.num_classes = num_classes
|
|
||||||
self.dropout_ratio = dropout_ratio
|
self.dropout_ratio = dropout_ratio
|
||||||
self.conv_cfg = conv_cfg
|
self.conv_cfg = conv_cfg
|
||||||
self.norm_cfg = norm_cfg
|
self.norm_cfg = norm_cfg
|
||||||
@ -84,6 +89,30 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.align_corners = align_corners
|
self.align_corners = align_corners
|
||||||
|
|
||||||
|
if out_channels is None:
|
||||||
|
if num_classes == 2:
|
||||||
|
warnings.warn('For binary segmentation, we suggest using'
|
||||||
|
'`out_channels = 1` to define the output'
|
||||||
|
'channels of segmentor, and use `threshold`'
|
||||||
|
'to convert seg_logist into a prediction'
|
||||||
|
'applying a threshold')
|
||||||
|
out_channels = num_classes
|
||||||
|
|
||||||
|
if out_channels != num_classes and out_channels != 1:
|
||||||
|
raise ValueError(
|
||||||
|
'out_channels should be equal to num_classes,'
|
||||||
|
'except binary segmentation set out_channels == 1 and'
|
||||||
|
f'num_classes == 2, but got out_channels={out_channels}'
|
||||||
|
f'and num_classes={num_classes}')
|
||||||
|
|
||||||
|
if out_channels == 1 and threshold is None:
|
||||||
|
threshold = 0.3
|
||||||
|
warnings.warn('threshold is not defined for binary, and defaults'
|
||||||
|
'to 0.3')
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
if isinstance(loss_decode, dict):
|
if isinstance(loss_decode, dict):
|
||||||
self.loss_decode = build_loss(loss_decode)
|
self.loss_decode = build_loss(loss_decode)
|
||||||
elif isinstance(loss_decode, (list, tuple)):
|
elif isinstance(loss_decode, (list, tuple)):
|
||||||
@ -99,7 +128,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||||||
else:
|
else:
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
|
||||||
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||||
if dropout_ratio > 0:
|
if dropout_ratio > 0:
|
||||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||||
else:
|
else:
|
||||||
|
@ -47,6 +47,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
|||||||
self.decode_head.append(builder.build_head(decode_head[i]))
|
self.decode_head.append(builder.build_head(decode_head[i]))
|
||||||
self.align_corners = self.decode_head[-1].align_corners
|
self.align_corners = self.decode_head[-1].align_corners
|
||||||
self.num_classes = self.decode_head[-1].num_classes
|
self.num_classes = self.decode_head[-1].num_classes
|
||||||
|
self.out_channels = self.decode_head[-1].out_channels
|
||||||
|
|
||||||
def encode_decode(self, img, img_metas):
|
def encode_decode(self, img, img_metas):
|
||||||
"""Encode images with backbone and decode into a semantic segmentation
|
"""Encode images with backbone and decode into a semantic segmentation
|
||||||
|
@ -49,6 +49,7 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
self.decode_head = builder.build_head(decode_head)
|
self.decode_head = builder.build_head(decode_head)
|
||||||
self.align_corners = self.decode_head.align_corners
|
self.align_corners = self.decode_head.align_corners
|
||||||
self.num_classes = self.decode_head.num_classes
|
self.num_classes = self.decode_head.num_classes
|
||||||
|
self.out_channels = self.decode_head.out_channels
|
||||||
|
|
||||||
def _init_auxiliary_head(self, auxiliary_head):
|
def _init_auxiliary_head(self, auxiliary_head):
|
||||||
"""Initialize ``auxiliary_head``"""
|
"""Initialize ``auxiliary_head``"""
|
||||||
@ -162,10 +163,10 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
h_stride, w_stride = self.test_cfg.stride
|
h_stride, w_stride = self.test_cfg.stride
|
||||||
h_crop, w_crop = self.test_cfg.crop_size
|
h_crop, w_crop = self.test_cfg.crop_size
|
||||||
batch_size, _, h_img, w_img = img.size()
|
batch_size, _, h_img, w_img = img.size()
|
||||||
num_classes = self.num_classes
|
out_channels = self.out_channels
|
||||||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
||||||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
||||||
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
|
preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
|
||||||
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
||||||
for h_idx in range(h_grids):
|
for h_idx in range(h_grids):
|
||||||
for w_idx in range(w_grids):
|
for w_idx in range(w_grids):
|
||||||
@ -245,7 +246,10 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
seg_logit = self.slide_inference(img, img_meta, rescale)
|
seg_logit = self.slide_inference(img, img_meta, rescale)
|
||||||
else:
|
else:
|
||||||
seg_logit = self.whole_inference(img, img_meta, rescale)
|
seg_logit = self.whole_inference(img, img_meta, rescale)
|
||||||
output = F.softmax(seg_logit, dim=1)
|
if self.out_channels == 1:
|
||||||
|
output = F.sigmoid(seg_logit)
|
||||||
|
else:
|
||||||
|
output = F.softmax(seg_logit, dim=1)
|
||||||
flip = img_meta[0]['flip']
|
flip = img_meta[0]['flip']
|
||||||
if flip:
|
if flip:
|
||||||
flip_direction = img_meta[0]['flip_direction']
|
flip_direction = img_meta[0]['flip_direction']
|
||||||
@ -260,7 +264,11 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
def simple_test(self, img, img_meta, rescale=True):
|
def simple_test(self, img, img_meta, rescale=True):
|
||||||
"""Simple test with single image."""
|
"""Simple test with single image."""
|
||||||
seg_logit = self.inference(img, img_meta, rescale)
|
seg_logit = self.inference(img, img_meta, rescale)
|
||||||
seg_pred = seg_logit.argmax(dim=1)
|
if self.out_channels == 1:
|
||||||
|
seg_pred = (seg_logit >
|
||||||
|
self.decode_head.threshold).to(seg_logit).squeeze(1)
|
||||||
|
else:
|
||||||
|
seg_pred = seg_logit.argmax(dim=1)
|
||||||
if torch.onnx.is_in_onnx_export():
|
if torch.onnx.is_in_onnx_export():
|
||||||
# our inference backend only support 4D output
|
# our inference backend only support 4D output
|
||||||
seg_pred = seg_pred.unsqueeze(0)
|
seg_pred = seg_pred.unsqueeze(0)
|
||||||
@ -283,7 +291,11 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
|
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
|
||||||
seg_logit += cur_seg_logit
|
seg_logit += cur_seg_logit
|
||||||
seg_logit /= len(imgs)
|
seg_logit /= len(imgs)
|
||||||
seg_pred = seg_logit.argmax(dim=1)
|
if self.out_channels == 1:
|
||||||
|
seg_pred = (seg_logit >
|
||||||
|
self.decode_head.threshold).to(seg_logit).squeeze(1)
|
||||||
|
else:
|
||||||
|
seg_pred = seg_logit.argmax(dim=1)
|
||||||
seg_pred = seg_pred.cpu().numpy()
|
seg_pred = seg_pred.cpu().numpy()
|
||||||
# unravel batch dim
|
# unravel batch dim
|
||||||
seg_pred = list(seg_pred)
|
seg_pred = list(seg_pred)
|
||||||
|
@ -43,6 +43,18 @@ def test_decode_head():
|
|||||||
in_index=[-1],
|
in_index=[-1],
|
||||||
input_transform='resize_concat')
|
input_transform='resize_concat')
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# out_channels should be equal to num_classes
|
||||||
|
BaseDecodeHead(32, 16, num_classes=19, out_channels=18)
|
||||||
|
|
||||||
|
# test out_channels
|
||||||
|
head = BaseDecodeHead(32, 16, num_classes=2)
|
||||||
|
assert head.out_channels == 2
|
||||||
|
|
||||||
|
# test out_channels == 1 and num_classes == 2
|
||||||
|
head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
|
||||||
|
assert head.out_channels == 1 and head.num_classes == 2
|
||||||
|
|
||||||
# test default dropout
|
# test default dropout
|
||||||
head = BaseDecodeHead(32, 16, num_classes=19)
|
head = BaseDecodeHead(32, 16, num_classes=19)
|
||||||
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
|
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
|
||||||
|
@ -18,6 +18,12 @@ def test_encoder_decoder():
|
|||||||
segmentor = build_segmentor(cfg)
|
segmentor = build_segmentor(cfg)
|
||||||
_segmentor_forward_train_test(segmentor)
|
_segmentor_forward_train_test(segmentor)
|
||||||
|
|
||||||
|
# test out_channels == 1
|
||||||
|
segmentor.out_channels = 1
|
||||||
|
segmentor.decode_head.out_channels = 1
|
||||||
|
segmentor.decode_head.threshold = 0.3
|
||||||
|
_segmentor_forward_train_test(segmentor)
|
||||||
|
|
||||||
# test slide mode
|
# test slide mode
|
||||||
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
|
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
|
||||||
segmentor = build_segmentor(cfg)
|
segmentor = build_segmentor(cfg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user