mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactor] Fix STDCNet custom loss
This commit is contained in:
parent
ad35ec6356
commit
1b3a4876a1
@ -1,7 +1,11 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from mmengine.data import PixelData
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from mmseg.core.data_structures.seg_data_sample import SegDataSample
|
||||||
|
from mmseg.core.utils import SampleList
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from .fcn_head import FCNHead
|
from .fcn_head import FCNHead
|
||||||
|
|
||||||
@ -31,13 +35,15 @@ class STDCHead(FCNHead):
|
|||||||
dtype=torch.float32).reshape(1, 3, 1, 1),
|
dtype=torch.float32).reshape(1, 3, 1, 1),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
def losses(self, seg_logit, seg_label):
|
def loss_by_feat(self, seg_logits: Tensor,
|
||||||
|
batch_data_samples: SampleList) -> dict:
|
||||||
"""Compute Detail Aggregation Loss."""
|
"""Compute Detail Aggregation Loss."""
|
||||||
# Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
|
# Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
|
||||||
# parameters. However, it is a constant in original repo and other
|
# parameters. However, it is a constant in original repo and other
|
||||||
# codebase because it would not be added into computation graph
|
# codebase because it would not be added into computation graph
|
||||||
# after threshold operation.
|
# after threshold operation.
|
||||||
seg_label = seg_label.to(self.laplacian_kernel)
|
seg_label = self._stack_batch_gt(batch_data_samples).to(
|
||||||
|
self.laplacian_kernel)
|
||||||
boundary_targets = F.conv2d(
|
boundary_targets = F.conv2d(
|
||||||
seg_label, self.laplacian_kernel, padding=1)
|
seg_label, self.laplacian_kernel, padding=1)
|
||||||
boundary_targets = boundary_targets.clamp(min=0)
|
boundary_targets = boundary_targets.clamp(min=0)
|
||||||
@ -67,12 +73,12 @@ class STDCHead(FCNHead):
|
|||||||
boundary_targets_x4_up[
|
boundary_targets_x4_up[
|
||||||
boundary_targets_x4_up <= self.boundary_threshold] = 0
|
boundary_targets_x4_up <= self.boundary_threshold] = 0
|
||||||
|
|
||||||
boudary_targets_pyramids = torch.stack(
|
boundary_targets_pyramids = torch.stack(
|
||||||
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
|
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
|
||||||
dim=1)
|
dim=1)
|
||||||
|
|
||||||
boudary_targets_pyramids = boudary_targets_pyramids.squeeze(2)
|
boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
|
||||||
boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids,
|
boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
|
||||||
self.fusion_kernel)
|
self.fusion_kernel)
|
||||||
|
|
||||||
boudary_targets_pyramid[
|
boudary_targets_pyramid[
|
||||||
@ -80,6 +86,13 @@ class STDCHead(FCNHead):
|
|||||||
boudary_targets_pyramid[
|
boudary_targets_pyramid[
|
||||||
boudary_targets_pyramid <= self.boundary_threshold] = 0
|
boudary_targets_pyramid <= self.boundary_threshold] = 0
|
||||||
|
|
||||||
loss = super(STDCHead, self).losses(seg_logit,
|
seg_labels = boudary_targets_pyramid.long()
|
||||||
boudary_targets_pyramid.long())
|
batch_sample_list = []
|
||||||
|
for label in seg_labels:
|
||||||
|
seg_data_sample = SegDataSample()
|
||||||
|
seg_data_sample.gt_sem_seg = PixelData(data=label)
|
||||||
|
batch_sample_list.append(seg_data_sample)
|
||||||
|
|
||||||
|
loss = super(STDCHead, self).loss_by_feat(seg_logits,
|
||||||
|
batch_sample_list)
|
||||||
return loss
|
return loss
|
||||||
|
Loading…
x
Reference in New Issue
Block a user