2021-08-17 14:16:55 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2020-07-07 20:52:19 +08:00
|
|
|
from abc import ABCMeta, abstractmethod
|
2022-06-19 14:32:09 +08:00
|
|
|
from typing import List
|
2020-07-07 20:52:19 +08:00
|
|
|
|
2022-06-19 14:32:09 +08:00
|
|
|
from torch import Tensor
|
|
|
|
|
2022-07-15 23:47:29 +08:00
|
|
|
from mmseg.utils import ConfigType
|
2020-07-07 20:52:19 +08:00
|
|
|
from .decode_head import BaseDecodeHead
|
|
|
|
|
|
|
|
|
|
|
|
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`.

    A cascade head's ``forward`` receives, in addition to the image
    features, the output of the previous decode head in the cascade.
    Subclasses must implement :meth:`forward`; :meth:`loss` and
    :meth:`predict` are built on top of it.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to BaseDecodeHead.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function.

        Args:
            inputs: Multi-level image features (``List[Tensor]`` as
                annotated in :meth:`loss` / :meth:`predict`).
            prev_output (Tensor): The output of previous decode head.

        Returns:
            The segmentation logits consumed by ``loss_by_feat`` and
            ``predict_by_feat``.
        """

    def loss(self, inputs: List[Tensor], prev_output: Tensor,
             batch_data_samples: List[dict], train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.
            train_cfg (dict): The training config. Not used by this base
                implementation; accepted for interface compatibility.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs: List[Tensor], prev_output: Tensor,
                batch_img_metas: List[dict], tese_cfg: ConfigType):
        """Forward function for testing.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (List[dict]): List of image info where each
                dict may also contain: 'img_shape', 'scale_factor', 'flip',
                'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            tese_cfg (dict): The testing config. (The name is a historical
                misspelling of ``test_cfg``, kept so keyword callers do not
                break.) Not used by this base implementation.

        Returns:
            Tensor: Output segmentation map, as produced by
            ``self.predict_by_feat``.
        """
        seg_logits = self.forward(inputs, prev_output)
        return self.predict_by_feat(seg_logits, batch_img_metas)
|