# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn.utils import initialize

from easycv.core.evaluation.metrics import accuracy
from easycv.models.builder import build_loss
from easycv.models.utils.ops import resize_tensor
from easycv.utils.logger import print_log

# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        in_index (int|Sequence[int]): Input feature index. Default: -1.
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in the FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into the decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is a property of the corresponding loss function
            and is shown in the training log. If you want a loss item to be
            included in the backward graph, its `loss_name` must start with
            the prefix `loss_`. Defaults to 'loss_ce'.
            e.g. dict(type='CrossEntropyLoss'),
            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
             dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """
    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners
        self.init_cfg = init_cfg

        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            raise TypeError('loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False
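
    # Editor's note (illustrative, not part of the original file): a config
    # that combines two decode losses could look like
    #     loss_decode=[
    #         dict(type='CrossEntropyLoss', loss_weight=1.0, loss_name='loss_ce'),
    #         dict(type='DiceLoss', loss_weight=3.0, loss_name='loss_dice'),
    #     ]
    # where the loss weights are hypothetical. Each entry is built with
    # build_loss and stored in an nn.ModuleList, so any parameters it holds
    # are registered with the head.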

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        lists or tuples of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in the FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels
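
    # Editor's note (illustrative, not part of the original file): with a
    # hypothetical HRNet-style setting in_channels=[18, 36, 72, 144],
    # in_index=[0, 1, 2, 3] and input_transform='resize_concat',
    # self.in_channels resolves to 270 (the sum of the channel counts);
    # with 'multiple_select' it stays the list [18, 36, 72, 144]; with
    # input_transform=None, in_channels must be a single int, e.g. 144.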

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize_tensor(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss."""
        loss = dict()
        seg_logit = resize_tensor(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logit, seg_label, ignore_index=self.ignore_index)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    seg_logit, seg_label, ignore_index=self.ignore_index)

        loss['acc_seg'] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss
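
    # Editor's note (illustrative, not part of the original file): the logits
    # are first upsampled to the label resolution, then each configured loss
    # is evaluated; losses sharing the same loss_name are summed into one
    # entry. Per the class docstring, only keys prefixed with 'loss_'
    # contribute to the backward graph, so 'acc_seg' is logged but not
    # backpropagated.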

    def init_weights(self):
        module_name = self.__class__.__name__

        if self.init_cfg:
            print_log(
                f'initialize {module_name} with init_cfg {self.init_cfg}')
            initialize(self, self.init_cfg)
            if isinstance(self.init_cfg, dict):
                # prevent the parameters of the pre-trained model from being overwritten by the `init_weights`
                if self.init_cfg['type'] == 'Pretrained':
                    logging.warning('Skip `init_cfg` with `Pretrained` type!')
                    return

        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights()
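

# ----------------------------------------------------------------------------
# Editor's sketch (not part of the original EasyCV file): a minimal, hedged
# example of how a concrete head might subclass BaseDecodeHead. The class
# name `_ToyHead` and all config values below are hypothetical, and the
# forward_train call assumes EasyCV's registered 'CrossEntropyLoss' accepts
# the ignore_index keyword as this base class expects. Real heads add richer
# conv blocks before cls_seg. Guarded so importing this module is unaffected.
# ----------------------------------------------------------------------------
if __name__ == '__main__':

    class _ToyHead(BaseDecodeHead):
        """Hypothetical head: a single 3x3 conv followed by cls_seg."""

        def __init__(self, **kwargs):
            super(_ToyHead, self).__init__(**kwargs)
            self.conv = nn.Conv2d(
                self.in_channels, self.channels, 3, padding=1)

        def forward(self, inputs):
            # Select/transform the backbone features, then classify per pixel.
            x = self._transform_inputs(inputs)
            return self.cls_seg(self.conv(x))

    # Build the head on a single 64-channel feature map (input_transform=None).
    head = _ToyHead(
        in_channels=64,
        channels=32,
        num_classes=19,
        in_index=-1,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))

    feats = [torch.randn(2, 64, 32, 32)]
    labels = torch.randint(0, 19, (2, 1, 128, 128))
    # forward_train runs forward() and then losses(), which resizes the
    # logits to the label resolution before computing the loss dict.
    print(head.forward_train(
        feats, img_metas=[], gt_semantic_seg=labels, train_cfg=dict()))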