mmsegmentation/projects/hssn/decode_head/sep_aspp_contrast_head.py
Tianlong Ai 432628b735
[Fix] Rename and Fix bug of projects HieraSeg (old PR #2444) (#2565)
## Motivation
Supplementary PR #2444 
Fix a tiny bug and add `loss_by_feat()` to compute the training loss.
The inference process has been verified to be accurate.
## Modification
- modify `sep_aspp_contrast_head.py`: add a `loss_by_feat()` function for
training (training still has a bug; will fix in the future 😫)
- fix the test command path in README.md from `bash tools/dist_test.sh
projects/HieraSeg_project/` to `bash tools/dist_test.sh
projects/HieraSeg/`
2023-02-06 18:55:22 +08:00

187 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from torch import Tensor
from mmseg.models.decode_heads.sep_aspp_head import DepthwiseSeparableASPPHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import SampleList
class ProjectionHead(nn.Module):
    """Project a feature map to ``proj_dim`` channels and L2-normalize it.

    Args:
        dim_in (int): Number of input channels.
        norm_cfg (dict): Config of the norm layer. Only used by the
            ``'convmlp'`` projection.
        proj_dim (int): Number of output channels. Default: 256.
        proj (str): Projection type. ``'linear'`` is a single 1x1 conv.
            ``'convmlp'`` is 1x1 conv -> norm -> ReLU -> 1x1 conv.
            Default: ``'convmlp'``.
    """

    def __init__(self,
                 dim_in: int,
                 norm_cfg: dict,
                 proj_dim: int = 256,
                 proj: str = 'convmlp'):
        super().__init__()
        assert proj in ('convmlp', 'linear')
        if proj != 'convmlp':
            # 'linear': a single 1x1 conv straight to the output width.
            self.proj = nn.Conv2d(dim_in, proj_dim, kernel_size=1)
            return
        # Layers are created in the same order as they are applied. This
        # keeps parameter initialization and Sequential indices (and hence
        # state_dict keys) stable.
        hidden_conv = nn.Conv2d(dim_in, dim_in, kernel_size=1)
        norm_layer = build_norm_layer(norm_cfg, dim_in)[1]
        activation = nn.ReLU(inplace=True)
        output_conv = nn.Conv2d(dim_in, proj_dim, kernel_size=1)
        self.proj = nn.Sequential(hidden_conv, norm_layer, activation,
                                  output_conv)

    def forward(self, x):
        """Project ``x`` and scale it to unit L2 norm along dim 1."""
        projected = self.proj(x)
        return torch.nn.functional.normalize(projected, p=2, dim=1)
@MODELS.register_module()
class DepthwiseSeparableASPPContrastHead(DepthwiseSeparableASPPHead):
    """Deep Hierarchical Semantic Segmentation (HSSN) decode head.

    This head is the implementation of
    `HSSN <https://arxiv.org/abs/2203.14335>`_. It is based on Encoder-Decoder
    with Atrous Separable Convolution for Semantic Image Segmentation
    (`DeepLabV3+ <https://arxiv.org/abs/1802.02611>`_).

    Besides the DeepLabV3+ logits, :meth:`forward` also returns an embedding
    of the last input feature map. The embedding is produced by a
    :class:`ProjectionHead`.

    Args:
        proj (str): The type of ProjectionHead, 'linear' or 'convmlp'.
            Default: 'convmlp'.
        proj_in_channels (int): Number of channels of ``inputs[-1]``, the
            feature map fed to the projection head. Default: 2048.
        **kwargs: Forwarded to :class:`DepthwiseSeparableASPPHead`.
    """

    # Hierarchy merge tables used at prediction time, keyed by the total
    # number of output channels (leaf classes + hierarchy classes). Each value
    # is ``(num_hiera_classes, groups)``. Every group ``(start, end, parents)``
    # adds the logits of each channel in ``parents`` to the leaf channels
    # ``start:end``. The hierarchy channels are the last
    # ``num_hiera_classes`` channels.
    _HIERA_MERGE_GROUPS = {
        # Cityscapes: 19 leaf classes + 7 hierarchy classes.
        26: (7, (
            (0, 2, (19, )),
            (2, 5, (20, )),
            (5, 8, (21, )),
            (8, 10, (22, )),
            (10, 11, (23, )),
            (11, 13, (24, )),
            (13, 19, (25, )),
        )),
        # Pascal person: 7 leaf classes + 5 hierarchy classes.
        12: (5, (
            (0, 1, (7, 10)),
            (1, 5, (8, 11)),
            (5, 7, (9, 11)),
        )),
        # LIP: 20 leaf classes + 5 hierarchy classes.
        25: (5, (
            (0, 1, (20, 23)),
            (1, 8, (21, 24)),
            (10, 12, (21, 24)),
            (13, 16, (21, 24)),
            (8, 10, (22, 24)),
            (12, 13, (22, 24)),
            (16, 20, (22, 24)),
        )),
        # Mapillary (124 + 16 + 4 = 144) is not supported: the unofficial
        # repository had not released it as of 2023/2/6.
    }

    def __init__(self,
                 proj: str = 'convmlp',
                 proj_in_channels: int = 2048,
                 **kwargs):
        super().__init__(**kwargs)
        self.proj_head = ProjectionHead(
            dim_in=proj_in_channels, norm_cfg=self.norm_cfg, proj=proj)
        # Counter incremented on every forward call and passed to
        # ``self.loss_decode`` as its first argument in ``loss_by_feat``.
        self.register_buffer('step', torch.zeros(1))

    def forward(self, inputs) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            inputs: Multi-level features. They are passed to the DeepLabV3+
                forward, and ``inputs[-1]`` also goes to the projection head.

        Returns:
            tuple[Tensor, Tensor]: ``(seg_logits, embedding)``.
        """
        output = super().forward(inputs)
        # NOTE(review): this also advances in eval mode. Confirm whether the
        # loss schedule should only count training iterations.
        self.step += 1
        embedding = self.proj_head(inputs[-1])
        return output, embedding

    def _merge_hiera_logits(self, seg_logit: Tensor) -> Tensor:
        """Fold hierarchy-level logits into leaf logits and drop them.

        Args:
            seg_logit (Tensor): Logits of shape (N, C, H, W). C is the number
                of leaf classes plus hierarchy classes.

        Returns:
            Tensor: Leaf logits of shape (N, C - num_hiera_classes, H, W).
            The input tensor is not modified.

        Raises:
            ValueError: If C does not match a known class hierarchy.
        """
        num_channels = seg_logit.size(1)
        if num_channels not in self._HIERA_MERGE_GROUPS:
            raise ValueError(
                f'Unsupported number of output channels {num_channels} for '
                f'hierarchical merging; expected one of '
                f'{sorted(self._HIERA_MERGE_GROUPS)}.')
        num_hiera_classes, groups = self._HIERA_MERGE_GROUPS[num_channels]
        # Copy the leaf channels so the caller's tensor is left untouched.
        leaf_logit = seg_logit[:, :num_channels - num_hiera_classes].clone()
        for start, end, parents in groups:
            for parent in parents:
                # ``parent:parent + 1`` keeps the channel dim, so the
                # (N, 1, H, W) parent logit broadcasts over the leaf channels
                # of the same sample. ``[:, parent]`` would drop that dim and
                # broadcast across the batch dimension instead.
                leaf_logit[:, start:end] += seg_logit[:, parent:parent + 1]
        return leaf_logit

    def predict_by_feat(self, seg_logits: Tuple[Tensor],
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (Tuple[Tensor] | Tensor): The output from the decode
                head forward function, either ``(out, embedding)`` or
                ``out`` alone.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map (leaf classes only).
        """
        # The HSSN decode head outputs (out, embedding); only 'out' is
        # needed here.
        if isinstance(seg_logits, (tuple, list)):
            seg_logit = seg_logits[0]
        else:
            seg_logit = seg_logits
        seg_logit = self._merge_hiera_logits(seg_logit)
        return resize(
            input=seg_logit,
            size=batch_img_metas[0]['img_shape'],
            mode='bilinear',
            align_corners=self.align_corners)

    def loss_by_feat(
            self,
            seg_logits: Tuple[Tensor],  # (out, embedding)
            batch_data_samples: SampleList) -> dict:
        """Compute segmentation loss.

        NOTE: upstream marks training with this head as not yet working
        ("will fix in future").

        Args:
            seg_logits (Tuple[Tensor]): The output from the decode head
                forward function. For this head it is ``(out, embedding)``.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logit_before = seg_logits[0]
        embedding = seg_logits[1]
        seg_label = self._stack_batch_gt(batch_data_samples)
        losses = {}
        # Upsample logits to the label resolution for the loss and accuracy.
        seg_logit = resize(
            input=seg_logit_before,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)
        # The raw logits are also downsampled by 2x and passed to the loss
        # together with the embedding.
        seg_logit_before = resize(
            input=seg_logit_before,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        losses['loss_seg'] = self.loss_decode(
            self.step,
            embedding,
            seg_logit_before,
            seg_logit,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        # Exclude ignored pixels from the accuracy, matching mmseg's
        # BaseDecodeHead.loss_by_feat.
        # NOTE(review): argmax here spans the hierarchy channels too.
        losses['acc_seg'] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return losses