From 1b3a4876a1689b8f1f79a8aadf12238583440067 Mon Sep 17 00:00:00 2001 From: "xiexinchen.vendor" Date: Tue, 12 Jul 2022 10:20:41 +0000 Subject: [PATCH] [Refactor] Fix STDCNet custom loss --- mmseg/models/decode_heads/stdc_head.py | 27 +++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/mmseg/models/decode_heads/stdc_head.py b/mmseg/models/decode_heads/stdc_head.py index e18354ffb..f8601b20a 100644 --- a/mmseg/models/decode_heads/stdc_head.py +++ b/mmseg/models/decode_heads/stdc_head.py @@ -1,7 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn.functional as F +from mmengine.data import PixelData +from torch import Tensor +from mmseg.core.data_structures.seg_data_sample import SegDataSample +from mmseg.core.utils import SampleList from mmseg.registry import MODELS from .fcn_head import FCNHead @@ -31,13 +35,15 @@ class STDCHead(FCNHead): dtype=torch.float32).reshape(1, 3, 1, 1), requires_grad=False) - def losses(self, seg_logit, seg_label): + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: """Compute Detail Aggregation Loss.""" # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv # parameters. However, it is a constant in original repo and other # codebase because it would not be added into computation graph # after threshold operation. - seg_label = seg_label.to(self.laplacian_kernel) + seg_label = self._stack_batch_gt(batch_data_samples).to( + self.laplacian_kernel) boundary_targets = F.conv2d( seg_label, self.laplacian_kernel, padding=1) boundary_targets = boundary_targets.clamp(min=0) @@ -67,12 +73,12 @@ class STDCHead(FCNHead): boundary_targets_x4_up[ boundary_targets_x4_up <= self.boundary_threshold] = 0 - boudary_targets_pyramids = torch.stack( + boundary_targets_pyramids = torch.stack( (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), dim=1) - boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2) - boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids, + boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) + boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids, self.fusion_kernel) boudary_targets_pyramid[ @@ -80,6 +86,13 @@ class STDCHead(FCNHead): boudary_targets_pyramid[ boudary_targets_pyramid <= self.boundary_threshold] = 0 - loss = super(STDCHead, self).losses(seg_logit, - boudary_targets_pyramid.long()) + seg_labels = boudary_targets_pyramid.long() + batch_sample_list = [] + for label in seg_labels: + seg_data_sample = SegDataSample() + seg_data_sample.gt_sem_seg = PixelData(data=label) + batch_sample_list.append(seg_data_sample) + + loss = super(STDCHead, self).loss_by_feat(seg_logits, + batch_sample_list) return loss