[Refactor] Fix STDCNet custom loss
parent
ad35ec6356
commit
1b3a4876a1
|
@ -1,7 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.data import PixelData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.core.data_structures.seg_data_sample import SegDataSample
|
||||
from mmseg.core.utils import SampleList
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
@ -31,13 +35,15 @@ class STDCHead(FCNHead):
|
|||
dtype=torch.float32).reshape(1, 3, 1, 1),
|
||||
requires_grad=False)
|
||||
|
||||
def losses(self, seg_logit, seg_label):
|
||||
def loss_by_feat(self, seg_logits: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
"""Compute Detail Aggregation Loss."""
|
||||
# Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
|
||||
# parameters. However, it is a constant in original repo and other
|
||||
# codebase because it would not be added into computation graph
|
||||
# after threshold operation.
|
||||
seg_label = seg_label.to(self.laplacian_kernel)
|
||||
seg_label = self._stack_batch_gt(batch_data_samples).to(
|
||||
self.laplacian_kernel)
|
||||
boundary_targets = F.conv2d(
|
||||
seg_label, self.laplacian_kernel, padding=1)
|
||||
boundary_targets = boundary_targets.clamp(min=0)
|
||||
|
@ -67,12 +73,12 @@ class STDCHead(FCNHead):
|
|||
boundary_targets_x4_up[
|
||||
boundary_targets_x4_up <= self.boundary_threshold] = 0
|
||||
|
||||
boudary_targets_pyramids = torch.stack(
|
||||
boundary_targets_pyramids = torch.stack(
|
||||
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
|
||||
dim=1)
|
||||
|
||||
boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2)
|
||||
boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids,
|
||||
boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
|
||||
boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
|
||||
self.fusion_kernel)
|
||||
|
||||
boudary_targets_pyramid[
|
||||
|
@ -80,6 +86,13 @@ class STDCHead(FCNHead):
|
|||
boudary_targets_pyramid[
|
||||
boudary_targets_pyramid <= self.boundary_threshold] = 0
|
||||
|
||||
loss = super(STDCHead, self).losses(seg_logit,
|
||||
boudary_targets_pyramid.long())
|
||||
seg_labels = boudary_targets_pyramid.long()
|
||||
batch_sample_list = []
|
||||
for label in seg_labels:
|
||||
seg_data_sample = SegDataSample()
|
||||
seg_data_sample.gt_sem_seg = PixelData(data=label)
|
||||
batch_sample_list.append(seg_data_sample)
|
||||
|
||||
loss = super(STDCHead, self).loss_by_feat(seg_logits,
|
||||
batch_sample_list)
|
||||
return loss
|
||||
|
|
Loading…
Reference in New Issue