Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Fix] Change self.loss_decode back to dict in Single Loss situation. (#1002)

* fix single loss type
* fix error in ohem & point_head
* fix coverage miss
* fix uncoverage error of PointHead loss
* fix coverage miss
* fix uncoverage error of PointHead loss
* nn.modules.container.ModuleList to nn.ModuleList
* more simple format
* merge unittest def
Commit: 992d577783
Parent: d7f82e5dc8
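In short: when `loss_decode` is configured as a single dict, `BaseDecodeHead` now stores the built loss module directly instead of wrapping it in a one-element `nn.ModuleList`, and every call site normalizes back to a list before iterating. A minimal standalone sketch of that normalization idiom (the helper name `as_loss_list` is ours, for illustration only):

import torch.nn as nn

def as_loss_list(loss_decode):
    # Hypothetical helper illustrating the idiom repeated in this commit:
    # a single loss module is wrapped in a list, an nn.ModuleList is
    # iterated as-is, so downstream loops handle both cases uniformly.
    if not isinstance(loss_decode, nn.ModuleList):
        return [loss_decode]
    return loss_decode

# Usage: both cases iterate the same way.
single = nn.CrossEntropyLoss()
multi = nn.ModuleList([nn.CrossEntropyLoss(), nn.CrossEntropyLoss()])
assert len(as_loss_list(single)) == 1
assert len(as_loss_list(multi)) == 2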
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from ..builder import PIXEL_SAMPLERS
@@ -62,14 +63,19 @@ class OHEMPixelSampler(BasePixelSampler):
                 threshold = max(min_threshold, self.thresh)
                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
             else:
+                if not isinstance(self.context.loss_decode, nn.ModuleList):
+                    losses_decode = [self.context.loss_decode]
+                else:
+                    losses_decode = self.context.loss_decode
                 losses = 0.0
-                for loss_module in self.context.loss_decode:
+                for loss_module in losses_decode:
                     losses += loss_module(
                         seg_logit,
                         seg_label,
                         weight=None,
                         ignore_index=self.context.ignore_index,
                         reduction_override='none')
+
                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                 _, sort_indices = losses[valid_mask].sort(descending=True)
                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
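Why the accumulator still works with one or many entries: each loss module is called with `reduction_override='none'`, so it returns a per-pixel loss map, and summing the maps preserves the shape used for ranking. A standalone sketch with plain `F.cross_entropy` standing in for mmseg's loss modules (shapes are illustrative):

import torch
import torch.nn.functional as F

seg_logit = torch.randn(1, 19, 45, 45)         # (N, C, H, W) logits
seg_label = torch.randint(0, 19, (1, 45, 45))  # (N, H, W) labels

# Stand-ins for loss_decode entries; reduction='none' keeps per-pixel
# values, mirroring reduction_override='none' in the sampler.
losses_decode = [
    lambda logit, label: F.cross_entropy(logit, label, reduction='none'),
    lambda logit, label: F.cross_entropy(logit, label, reduction='none'),
]

losses = 0.0
for loss_module in losses_decode:
    losses += loss_module(seg_logit, seg_label)  # accumulates per-pixel maps

assert losses.shape == (1, 45, 45)               # one loss value per pixel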
@@ -83,11 +83,11 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
 
         self.ignore_index = ignore_index
         self.align_corners = align_corners
-        self.loss_decode = nn.ModuleList()
 
         if isinstance(loss_decode, dict):
-            self.loss_decode.append(build_loss(loss_decode))
+            self.loss_decode = build_loss(loss_decode)
         elif isinstance(loss_decode, (list, tuple)):
+            self.loss_decode = nn.ModuleList()
             for loss in loss_decode:
                 self.loss_decode.append(build_loss(loss))
         else:
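In effect, the type of `self.loss_decode` now depends on the config shape. A minimal sketch of the two cases, mirroring the unit tests further down (assumes a mmsegmentation install that includes this fix):

import torch.nn as nn
from mmseg.models.decode_heads import FCNHead

# Single dict config -> loss_decode is the built module itself.
head_single = FCNHead(
    in_channels=32, channels=16, num_classes=19,
    loss_decode=dict(type='CrossEntropyLoss'))
assert not isinstance(head_single.loss_decode, nn.ModuleList)

# List config -> loss_decode is an nn.ModuleList of built modules.
head_multi = FCNHead(
    in_channels=32, channels=16, num_classes=19,
    loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_1'),
        dict(type='CrossEntropyLoss', loss_name='loss_2')
    ])
assert isinstance(head_multi.loss_decode, nn.ModuleList)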
@@ -242,7 +242,12 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
-        for loss_decode in self.loss_decode:
+
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_decode in losses_decode:
             if loss_decode.loss_name not in loss:
                 loss[loss_decode.loss_name] = loss_decode(
                     seg_logit,
@@ -249,9 +249,14 @@ class PointHead(BaseCascadeDecodeHead):
     def losses(self, point_logits, point_label):
         """Compute segmentation loss."""
         loss = dict()
-        for loss_module in self.loss_decode:
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_module in losses_decode:
             loss['point' + loss_module.loss_name] = loss_module(
                 point_logits, point_label, ignore_index=self.ignore_index)
 
         loss['acc_point'] = accuracy(point_logits, point_label)
         return loss
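Each configured loss contributes a `'point' + loss_name` entry to the returned dict. A usage sketch mirroring the test below (fake logits and labels; assumes a mmsegmentation install that includes this fix):

import torch
from mmseg.models.decode_heads import PointHead

head = PointHead(
    in_channels=[32], in_index=[0], channels=16, num_classes=19,
    loss_decode=[
        dict(type='CrossEntropyLoss', loss_name='loss_1'),
        dict(type='CrossEntropyLoss', loss_name='loss_2')
    ])
fake_logits = torch.randn(1, 19, 180, 180)
fake_label = torch.ones(1, 180, 180, dtype=torch.long)
loss = head.losses(fake_logits, fake_label)
assert 'pointloss_1' in loss and 'pointloss_2' in loss and 'acc_point' in loss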
@@ -21,3 +21,41 @@ def test_point_head():
         subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
     output = point_head.forward_test(inputs, prev_output, None, test_cfg)
     assert output.shape == (1, point_head.num_classes, 180, 180)
+
+    # test multiple losses case
+    inputs = [torch.randn(1, 32, 45, 45)]
+    point_head_multiple_losses = PointHead(
+        in_channels=[32],
+        in_index=[0],
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    assert len(point_head_multiple_losses.fcs) == 3
+    fcn_head_multiple_losses = FCNHead(
+        in_channels=32,
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(point_head_multiple_losses, inputs)
+        head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
+    prev_output = fcn_head_multiple_losses(inputs)
+    test_cfg = ConfigDict(
+        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
+    output = point_head_multiple_losses.forward_test(inputs, prev_output, None,
+                                                     test_cfg)
+    assert output.shape == (1, point_head.num_classes, 180, 180)
+
+    fake_label = torch.ones([1, 180, 180], dtype=torch.long)
+
+    if torch.cuda.is_available():
+        fake_label = fake_label.cuda()
+    loss = point_head_multiple_losses.losses(output, fake_label)
+    assert 'pointloss_1' in loss
+    assert 'pointloss_2' in loss
@@ -10,6 +10,17 @@ def _context_for_ohem():
     return FCNHead(in_channels=32, channels=16, num_classes=19)
 
 
+def _context_for_ohem_multiple_loss():
+    return FCNHead(
+        in_channels=32,
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+
+
 def test_ohem_sampler():
 
     with pytest.raises(AssertionError):
@@ -37,3 +48,31 @@ def test_ohem_sampler():
     assert seg_weight.shape[0] == seg_logit.shape[0]
     assert seg_weight.shape[1:] == seg_logit.shape[2:]
     assert seg_weight.sum() == 200
+
+    # test multiple losses case
+    with pytest.raises(AssertionError):
+        # seg_logit and seg_label must be of the same size
+        sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss())
+        seg_logit = torch.randn(1, 19, 45, 45)
+        seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
+        sampler.sample(seg_logit, seg_label)
+
+    # test with thresh in multiple losses case
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() > 200
+
+    # test w.o thresh in multiple losses case
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem_multiple_loss(), min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() == 200