MengzhangLI c1e72513c9 [Feature] Support STDC Network (new) (#995)
* refactor stdc code

* update key

* fix backbone inference

* remove comments

* fixing errors

* fixing version conflict

* fix typo

* use STDCHead

* upload models&logs

* adding model converters script and fix unittest

* fix error

* fix error

* fix error

* delete redundant keys in config

* fix errors in configs and unittest

* fix errors in configs and unittest

* fix errors in configs and unittest

* change Memory name

* refactor stdc2mmseg

* change name to STDC

* refactor stdc

* refactor stdc

* stdc refactor

* stdc refactor

* stdc refactor

* stdc refactor

* stdc refactor

* stdc refactor

* refactor stdc

* stdc refactor

Co-authored-by: xiexinch <xinchen.xie@qq.com>
2021-12-10 23:09:32 +08:00

91 lines
3.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class STDCHead(FCNHead):
    """Detail head of `Rethinking BiSeNet For Real-time Semantic Segmentation
    <https://arxiv.org/abs/2104.13188>`_.

    Only construction and :meth:`losses` are overridden; everything else is
    inherited from ``FCNHead``. The loss is computed against a binary
    boundary map derived from the segmentation label rather than against the
    label itself.

    Args:
        boundary_threshold (float): Cut-off used every time a boundary
            response is binarized. Default: 0.1.
        **kwargs: Forwarded unchanged to ``FCNHead.__init__``.
    """

    def __init__(self, boundary_threshold=0.1, **kwargs):
        super().__init__(**kwargs)
        self.boundary_threshold = boundary_threshold

        # 3x3 Laplacian edge detector. It is stored as a buffer so that it
        # ends up on the same device as `seg_label`.
        laplacian = torch.tensor(
            [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
        self.register_buffer('laplacian_kernel',
                             laplacian.reshape(1, 1, 3, 3))

        # Fixed 1x1 conv weights that blend the three boundary scales
        # (full, 1/2 and 1/4 resolution).
        scale_weights = torch.tensor([6. / 10, 3. / 10, 1. / 10],
                                     dtype=torch.float32)
        self.fusion_kernel = torch.nn.Parameter(
            scale_weights.reshape(1, 3, 1, 1), requires_grad=False)

    def losses(self, seg_logit, seg_label):
        """Compute Detail Aggregation Loss.

        Args:
            seg_logit (Tensor): Output of the head. It is resized with
                bilinear interpolation (``align_corners=True``) to the
                spatial size of the boundary target before the parent loss
                is evaluated.
            seg_label (Tensor): Segmentation label, cast to float here.
                Presumably shaped (N, 1, H, W): the Laplacian kernel takes a
                single input channel and dims 2+ are used as the spatial
                size -- confirm against the caller.

        Returns:
            Whatever ``FCNHead.losses`` returns for the resized logits and
            the binary boundary target.
        """
        # Note: the paper describes `fusion_kernel` as a trainable 1x1 conv.
        # It is a constant in the original repo and in other codebases,
        # because it would not be part of the computation graph after the
        # threshold operation anyway.
        threshold = self.boundary_threshold
        label = seg_label.float()

        def edge_response(stride):
            # Laplacian response of the label map, negative part dropped.
            response = F.conv2d(
                label, self.laplacian_kernel, stride=stride, padding=1)
            return response.clamp(min=0)

        def binarize_(tensor):
            # In-place thresholding. The two masked writes stay sequential
            # on purpose: the second mask is evaluated after the first
            # write, so the 1s just written are compared against the
            # threshold again.
            tensor[tensor > threshold] = 1
            tensor[tensor <= threshold] = 0
            return tensor

        # Full-resolution boundary map; it also fixes the output size.
        full_scale = binarize_(edge_response(stride=1))
        out_size = full_scale.shape[2:]

        # Coarser boundary maps: strided Laplacian, nearest-neighbour
        # upsampling back to full resolution, then thresholding.
        half_scale = binarize_(
            F.interpolate(edge_response(stride=2), out_size, mode='nearest'))
        quarter_scale = binarize_(
            F.interpolate(edge_response(stride=4), out_size, mode='nearest'))

        # Stack the three scales along a new dim 1 and drop the singleton
        # channel left over from the convolution, so the scales become the
        # channels consumed by the 1x1 fusion conv.
        pyramid = torch.stack(
            (full_scale, half_scale, quarter_scale), dim=1).squeeze(2)
        fused_target = binarize_(F.conv2d(pyramid, self.fusion_kernel))

        resized_logit = F.interpolate(
            seg_logit, out_size, mode='bilinear', align_corners=True)
        return super().losses(resized_logit, fused_target.long())