58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
from abc import ABCMeta, abstractmethod
|
|
|
|
from .decode_head import BaseDecodeHead
|
|
|
|
|
|
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`.

    Unlike a plain decode head, every forward pass of a cascade decode head
    receives ``prev_output`` (the output of the previous decode head) in
    addition to the multi-level image features. Subclasses must implement
    :meth:`forward`; :meth:`forward_train` and :meth:`forward_test` only
    route ``prev_output`` into it.
    """

    def __init__(self, *args, **kwargs):
        # All construction arguments are forwarded unchanged to
        # BaseDecodeHead; this class adds no state of its own.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.

        Returns:
            Tensor: Segmentation logits. They are passed to ``self.losses``
            by :meth:`forward_train` and returned as-is by
            :meth:`forward_test`.
        """

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this base implementation.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        # ``self.losses`` is not defined in this class; it is expected to
        # come from BaseDecodeHead (or be overridden by a subclass).
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this base implementation.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, prev_output)
|