[Fix] Change `self.loss_decode` back to `dict` in Single Loss situation. (#1002)

* fix single loss type

* fix error in ohem & point_head

* fix coverage miss

* fix uncovered error in PointHead loss

* nn.modules.container.ModuleList to nn.ModuleList

* simpler format

* merge unit test definitions
MengzhangLI 2021-11-01 15:28:37 +08:00 committed by GitHub
parent d7f82e5dc8
commit 992d577783
5 changed files with 98 additions and 5 deletions
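
The substance of the change: a `loss_decode` built from a single config dict is a bare loss module again rather than a one-element nn.ModuleList, so every call site that used to iterate over it unconditionally must now normalize first. A minimal sketch of what breaks without that normalization (plain PyTorch; the `build_loss` stand-in is hypothetical, not mmseg's registry builder):

import torch.nn as nn


def build_loss(cfg):
    # stand-in for mmseg's registry-based builder, illustration only
    return nn.CrossEntropyLoss()


single = build_loss(dict(type='CrossEntropyLoss'))  # now a plain module
multiple = nn.ModuleList(
    [build_loss(dict(type='CrossEntropyLoss')) for _ in range(2)])

try:
    for loss_module in single:  # a plain module is not iterable
        pass
except TypeError as err:
    print(err)  # 'CrossEntropyLoss' object is not iterable

for loss_module in multiple:  # an nn.ModuleList iterates fine
    pass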

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+import torch.nn as nn
 import torch.nn.functional as F

 from ..builder import PIXEL_SAMPLERS
@@ -62,14 +63,19 @@ class OHEMPixelSampler(BasePixelSampler):
                 threshold = max(min_threshold, self.thresh)
                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
             else:
+                if not isinstance(self.context.loss_decode, nn.ModuleList):
+                    losses_decode = [self.context.loss_decode]
+                else:
+                    losses_decode = self.context.loss_decode
+
                 losses = 0.0
-                for loss_module in self.context.loss_decode:
+                for loss_module in losses_decode:
                     losses += loss_module(
                         seg_logit,
                         seg_label,
                         weight=None,
                         ignore_index=self.context.ignore_index,
                         reduction_override='none')
                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                 _, sort_indices = losses[valid_mask].sort(descending=True)
                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
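
The normalization idiom above recurs in the BaseDecodeHead and PointHead hunks below; written once as a helper it would look like this (hypothetical `_as_iterable` name — the PR inlines the check at each call site instead):

import torch.nn as nn


def _as_iterable(loss_decode):
    # wrap a bare loss module in a plain list so callers can always
    # iterate; an nn.ModuleList passes through unchanged
    if not isinstance(loss_decode, nn.ModuleList):
        return [loss_decode]
    return loss_decode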

@@ -83,11 +83,11 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         self.ignore_index = ignore_index
         self.align_corners = align_corners
-        self.loss_decode = nn.ModuleList()
         if isinstance(loss_decode, dict):
-            self.loss_decode.append(build_loss(loss_decode))
+            self.loss_decode = build_loss(loss_decode)
         elif isinstance(loss_decode, (list, tuple)):
+            self.loss_decode = nn.ModuleList()
             for loss in loss_decode:
                 self.loss_decode.append(build_loss(loss))
         else:
@@ -242,7 +242,12 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
-        for loss_decode in self.loss_decode:
+
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_decode in losses_decode:
             if loss_decode.loss_name not in loss:
                 loss[loss_decode.loss_name] = loss_decode(
                     seg_logit,
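
The viewer truncates the hunk above at `seg_logit,`. Assuming the untouched remainder still matches the multi-loss accumulation this method used before the change (an else branch adding into the same `loss_name` key), the full loop would read roughly as this sketch (hypothetical standalone function, not the method itself):

def compute_decode_losses(losses_decode, loss, seg_logit, seg_label,
                          seg_weight, ignore_index):
    # sketch of the full loop body, assuming the pre-existing behavior:
    # losses sharing a loss_name are summed into a single dict entry
    for loss_decode in losses_decode:
        if loss_decode.loss_name not in loss:
            loss[loss_decode.loss_name] = loss_decode(
                seg_logit, seg_label, weight=seg_weight,
                ignore_index=ignore_index)
        else:
            loss[loss_decode.loss_name] += loss_decode(
                seg_logit, seg_label, weight=seg_weight,
                ignore_index=ignore_index)
    return loss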

@@ -249,9 +249,14 @@ class PointHead(BaseCascadeDecodeHead):

     def losses(self, point_logits, point_label):
         """Compute segmentation loss."""
         loss = dict()
-        for loss_module in self.loss_decode:
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_module in losses_decode:
             loss['point' + loss_module.loss_name] = loss_module(
                 point_logits, point_label, ignore_index=self.ignore_index)
         loss['acc_point'] = accuracy(point_logits, point_label)
         return loss
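
The dict keys produced here are the literal concatenation of 'point' with each `loss_name`, which is exactly what the new test below asserts:

# loss_name='loss_1'/'loss_2' yields keys 'pointloss_1'/'pointloss_2';
# the default single-loss name 'loss_ce' would yield 'pointloss_ce'
keys = ['point' + name for name in ('loss_1', 'loss_2')]
assert keys == ['pointloss_1', 'pointloss_2']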

@@ -21,3 +21,41 @@ def test_point_head():
         subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
     output = point_head.forward_test(inputs, prev_output, None, test_cfg)
     assert output.shape == (1, point_head.num_classes, 180, 180)
+
+    # test multiple losses case
+    inputs = [torch.randn(1, 32, 45, 45)]
+    point_head_multiple_losses = PointHead(
+        in_channels=[32],
+        in_index=[0],
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    assert len(point_head_multiple_losses.fcs) == 3
+
+    fcn_head_multiple_losses = FCNHead(
+        in_channels=32,
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(point_head_multiple_losses, inputs)
+        head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
+    prev_output = fcn_head_multiple_losses(inputs)
+    test_cfg = ConfigDict(
+        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
+    output = point_head_multiple_losses.forward_test(inputs, prev_output,
+                                                     None, test_cfg)
+    assert output.shape == (1, point_head_multiple_losses.num_classes, 180,
+                            180)
+
+    fake_label = torch.ones([1, 180, 180], dtype=torch.long)
+    if torch.cuda.is_available():
+        fake_label = fake_label.cuda()
+    loss = point_head_multiple_losses.losses(output, fake_label)
+    assert 'pointloss_1' in loss
+    assert 'pointloss_2' in loss

@@ -10,6 +10,17 @@ def _context_for_ohem():
     return FCNHead(in_channels=32, channels=16, num_classes=19)


+def _context_for_ohem_multiple_loss():
+    return FCNHead(
+        in_channels=32,
+        channels=16,
+        num_classes=19,
+        loss_decode=[
+            dict(type='CrossEntropyLoss', loss_name='loss_1'),
+            dict(type='CrossEntropyLoss', loss_name='loss_2')
+        ])
+
+
 def test_ohem_sampler():
     with pytest.raises(AssertionError):
@@ -37,3 +48,31 @@ def test_ohem_sampler():
     assert seg_weight.shape[0] == seg_logit.shape[0]
     assert seg_weight.shape[1:] == seg_logit.shape[2:]
     assert seg_weight.sum() == 200
+
+    # test multiple losses case
+    with pytest.raises(AssertionError):
+        # seg_logit and seg_label must be of the same size
+        sampler = OHEMPixelSampler(context=_context_for_ohem_multiple_loss())
+        seg_logit = torch.randn(1, 19, 45, 45)
+        seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
+        sampler.sample(seg_logit, seg_label)
+
+    # test with thresh in multiple losses case
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem_multiple_loss(), thresh=0.7, min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() > 200
+
+    # test w/o thresh in multiple losses case
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem_multiple_loss(), min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() == 200
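
Finally, a quick smoke check of the repaired single-loss path, mirroring the tests above (imports assume mmseg's usual package layout; adjust if it differs):

import torch

from mmseg.core import OHEMPixelSampler
from mmseg.models.decode_heads import FCNHead

# a single-loss head: loss_decode is now a bare module, and OHEM
# sampling must still work with it
head = FCNHead(
    in_channels=32,
    channels=16,
    num_classes=19,
    loss_decode=dict(type='CrossEntropyLoss'))
sampler = OHEMPixelSampler(context=head, min_kept=200)
seg_weight = sampler.sample(
    torch.randn(1, 19, 45, 45),
    torch.randint(0, 19, size=(1, 1, 45, 45)))
assert seg_weight.sum() == 200  # exactly min_kept pixels kept w/o thresh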