mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import ConvModule, build_conv_layer
|
|
from torch import Tensor
|
|
|
|
from mmseg.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module()
class PPMobileSegHead(nn.Module):
    """The segmentation head.

    The head applies a 1x1 ``ConvModule`` (``linear_fuse``), 2D dropout and
    a 1x1 classification convolution (``conv_seg``), then resizes the
    result to the requested spatial size.

    Args:
        num_classes (int): Number of classes to predict.
        in_channels (int): Number of input channels.
        use_dw (bool): Whether ``linear_fuse`` is a depthwise convolution
            (``groups=in_channels``). Default: True.
        dropout_ratio (float): Probability of a channel to be zeroed by
            ``nn.Dropout2d``. Default: 0.1.
        align_corners (bool, optional): ``align_corners`` argument passed
            to ``F.interpolate``. Default: False.
        upsample (str): The upsample method. ``'intepolate'`` (legacy
            spelling, kept as the default for config compatibility) or
            ``'interpolate'`` bilinearly resizes the logits of all classes.
            ``'vim'`` is only used in eval mode with ``num_classes >= 30``:
            it resizes only the channels of the classes that are present
            in the low-resolution prediction and returns an int32 label
            map instead of logits. Default: 'intepolate'.
        out_channels (int, optional): The output channel. Only stored on
            the instance; it is not used by this head. Default: None.
        conv_cfg (dict): Config dict for convolution layer.
            Default: dict(type='Conv').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 use_dw=True,
                 dropout_ratio=0.1,
                 align_corners=False,
                 upsample='intepolate',
                 out_channels=None,
                 conv_cfg=dict(type='Conv'),
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN')):

        super().__init__()
        self.align_corners = align_corners
        self.last_channels = in_channels
        self.upsample = upsample
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.linear_fuse = ConvModule(
            in_channels=self.last_channels,
            out_channels=self.last_channels,
            kernel_size=1,
            bias=False,
            groups=self.last_channels if use_dw else 1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.dropout = nn.Dropout2d(dropout_ratio)
        self.conv_seg = build_conv_layer(
            conv_cfg, self.last_channels, self.num_classes, kernel_size=1)

    def forward(self, x):
        """Compute the segmentation output.

        Args:
            x (Sequence): ``x[0]`` is the feature map fed to the head and
                ``x[1]`` is the target spatial size handed to
                ``F.interpolate`` as ``size``.

        Returns:
            list[Tensor]: A single-element list. It holds the resized
            class logits, or, on the ``'vim'`` path, an int32 label map.

        Raises:
            NotImplementedError: If ``self.upsample`` is not a supported
                method (only checked in eval mode with
                ``num_classes >= 30``).
        """
        x, x_hw = x[0], x[1]
        x = self.linear_fuse(x)
        x = self.dropout(x)
        x = self.conv_seg(x)

        # Training and small label sets always resize the full logits.
        # 'intepolate' is the historical spelling used by existing configs;
        # the correct spelling is accepted as well.
        use_interpolate = (
            self.upsample in ('intepolate', 'interpolate') or self.training
            or self.num_classes < 30)

        if use_interpolate:
            x = F.interpolate(
                x, x_hw, mode='bilinear', align_corners=self.align_corners)
        elif self.upsample == 'vim':
            # Classes that win at least one low-resolution pixel.
            labelset = torch.unique(torch.argmax(x, 1))
            # Keep only those class channels. ``torch.gather`` would need
            # an index with the same rank as ``x``; selecting whole
            # channels is ``index_select``.
            x = torch.index_select(x, 1, labelset)
            x = F.interpolate(
                x, x_hw, mode='bilinear', align_corners=self.align_corners)

            # ``pred`` indexes into ``labelset``; look the original class
            # ids up in one step so the result stays on ``pred``'s device.
            pred = torch.argmax(x, 1)
            x = labelset[pred].to(torch.int32)
        else:
            raise NotImplementedError(
                f'{self.upsample} is not implemented')

        return [x]

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> Tensor:
        """Forward function for testing.

        Args:
            inputs (Sequence): Passed unchanged to :meth:`forward`.
            batch_img_metas (List[dict]): Unused by this head.
            test_cfg: Unused by this head.

        Returns:
            Tensor: The first (and only) element of the :meth:`forward`
            output.
        """
        seg_logits = self.forward(inputs)[0]
        return seg_logits
|