# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn.utils import initialize

from easycv.core.evaluation.metrics import accuracy
from easycv.models.builder import build_loss
from easycv.models.utils.ops import resize_tensor
from easycv.utils.logger import print_log


# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Subclasses must implement :meth:`forward`; this base class provides the
    input selection, the final per-pixel classifier (``conv_seg``), the loss
    computation and the weight initialization.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes. Keyword-only.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is a property of the corresponding loss function
            which could be shown in the training log. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
            e.g. dict(type='CrossEntropyLoss'),
            [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
             dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    # NOTE: the dict defaults below are created once and shared by every
    # call that relies on them; they are only read or stored here and must
    # not be mutated in place.
    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super().__init__()
        # Validates the (in_channels, in_index, input_transform) triple and
        # sets self.in_channels / self.in_index / self.input_transform.
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # Already stored by `_init_inputs`; the repeated assignment is
        # harmless and kept so the attribute is always set here.
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        self.init_cfg = init_cfg

        if isinstance(loss_decode, dict):
            # A single loss config yields a single loss module.
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            # Several loss configs are kept in a ModuleList, in order.
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            # Implicit string concatenation keeps the message on one line
            # without embedding source indentation in it.
            raise TypeError(
                'loss_decode must be a dict or sequence of dict, '
                f'but got {type(loss_decode)}')

        # 1x1 convolution mapping `channels` features to per-class logits.
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.

        Raises:
            AssertionError: If the three arguments are inconsistent.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select'], (
                "input_transform must be 'resize_concat', 'multiple_select' "
                f'or None, but got {input_transform!r}')
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple)), (
                'in_channels must be a list or tuple when input_transform '
                f'is set, but got {type(in_channels)}')
            assert isinstance(in_index, (list, tuple)), (
                'in_index must be a list or tuple when input_transform '
                f'is set, but got {type(in_index)}')
            assert len(in_channels) == len(in_index), (
                'in_channels and in_index must have the same length, but '
                f'got {len(in_channels)} and {len(in_index)}')
            if input_transform == 'resize_concat':
                # Concatenation along the channel axis adds channel counts.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int), (
                'in_channels must be an int when input_transform is None, '
                f'but got {type(in_channels)}')
            assert isinstance(in_index, int), (
                'in_index must be an int when input_transform is None, '
                f'but got {type(in_index)}')
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs. A single tensor
                for 'resize_concat' and for None, a list of the selected
                tensors for 'multiple_select'.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            # Every selected map is resized to the spatial size (dims 2+)
            # of the first selected map before channel-wise concatenation.
            target_size = inputs[0].shape[2:]
            upsampled_inputs = [
                resize_tensor(
                    input=x,
                    size=target_size,
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this base implementation.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this base implementation.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel.

        Applies dropout (when enabled) followed by the ``conv_seg``
        classifier.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss.

        Args:
            seg_logit (Tensor): Logits produced by :meth:`forward`.
            seg_label (Tensor): Ground-truth segmentation map. Dimension 1 is
                squeezed away, so this presumably has shape (N, 1, H, W) --
                TODO confirm against the data pipeline.

        Returns:
            dict[str, Tensor]: One entry per distinct ``loss_name`` (losses
                sharing a name are summed) plus ``'acc_seg'``.
        """
        loss = dict()
        # Bring the logits to the label's spatial size before comparing.
        seg_logit = resize_tensor(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        # Normalise the single-loss case to an iterable of loss modules.
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode

        for loss_decode in losses_decode:
            loss_value = loss_decode(
                seg_logit, seg_label, ignore_index=self.ignore_index)
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_value
            else:
                # Losses configured with the same name are accumulated.
                loss[loss_decode.loss_name] += loss_value

        loss['acc_seg'] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss

    def init_weights(self):
        """Initialize the head's weights according to ``self.init_cfg``.

        When ``init_cfg`` is set it is applied through ``initialize``. If it
        is a dict of type ``'Pretrained'``, child modules are not
        re-initialized afterwards. Otherwise every direct child exposing an
        ``init_weights`` method has it called.
        """
        module_name = self.__class__.__name__

        if self.init_cfg:
            print_log(
                f'initialize {module_name} with init_cfg {self.init_cfg}')
            initialize(self, self.init_cfg)
            # Prevent the parameters of the pre-trained model from being
            # overwritten by the children's `init_weights`.
            if (isinstance(self.init_cfg, dict)
                    and self.init_cfg['type'] == 'Pretrained'):
                logging.warning('Skip `init_cfg` with `Pretrained` type!')
                return

        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights()