[Refactor] Fix STDCNet custom loss

pull/1801/head
xiexinchen.vendor 2022-07-12 10:20:41 +00:00 committed by zhengmiao
parent ad35ec6356
commit 1b3a4876a1
1 changed files with 20 additions and 7 deletions

View File

@ -1,7 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.data import PixelData
from torch import Tensor
from mmseg.core.data_structures.seg_data_sample import SegDataSample
from mmseg.core.utils import SampleList
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@ -31,13 +35,15 @@ class STDCHead(FCNHead):
dtype=torch.float32).reshape(1, 3, 1, 1),
requires_grad=False)
def losses(self, seg_logit, seg_label):
def loss_by_feat(self, seg_logits: Tensor,
batch_data_samples: SampleList) -> dict:
"""Compute Detail Aggregation Loss."""
# Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
# parameters. However, it is a constant in original repo and other
# codebase because it would not be added into computation graph
# after threshold operation.
seg_label = seg_label.to(self.laplacian_kernel)
seg_label = self._stack_batch_gt(batch_data_samples).to(
self.laplacian_kernel)
boundary_targets = F.conv2d(
seg_label, self.laplacian_kernel, padding=1)
boundary_targets = boundary_targets.clamp(min=0)
@ -67,12 +73,12 @@ class STDCHead(FCNHead):
boundary_targets_x4_up[
boundary_targets_x4_up <= self.boundary_threshold] = 0
boudary_targets_pyramids = torch.stack(
boundary_targets_pyramids = torch.stack(
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
dim=1)
boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2)
boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids,
boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
self.fusion_kernel)
boudary_targets_pyramid[
@ -80,6 +86,13 @@ class STDCHead(FCNHead):
boudary_targets_pyramid[
boudary_targets_pyramid <= self.boundary_threshold] = 0
loss = super(STDCHead, self).losses(seg_logit,
boudary_targets_pyramid.long())
seg_labels = boudary_targets_pyramid.long()
batch_sample_list = []
for label in seg_labels:
seg_data_sample = SegDataSample()
seg_data_sample.gt_sem_seg = PixelData(data=label)
batch_sample_list.append(seg_data_sample)
loss = super(STDCHead, self).loss_by_feat(seg_logits,
batch_sample_list)
return loss