mmsegmentation/mmseg/models/decode_heads/cascade_decode_head.py

64 lines
2.3 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
2020-07-07 20:52:19 +08:00
from abc import ABCMeta, abstractmethod
from typing import List
2020-07-07 20:52:19 +08:00
from torch import Tensor
from mmseg.core.utils import ConfigType
2020-07-07 20:52:19 +08:00
from .decode_head import BaseDecodeHead
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode heads used in
    :class:`CascadeEncoderDecoder`.

    A cascade decode head's :meth:`forward` receives the output of the
    previous decode head in addition to the multi-level features.
    Subclasses must implement :meth:`forward`; :meth:`loss` and
    :meth:`predict` run it and hand the result to ``loss_by_feat`` /
    ``predict_by_feat`` respectively.
    """

    def __init__(self, *args, **kwargs):
        # All construction arguments are forwarded unchanged to
        # ``BaseDecodeHead``.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs: List[Tensor], prev_output: Tensor):
        """Placeholder of forward function.

        Subclasses must override this. It is called by :meth:`loss` and
        :meth:`predict` with the multi-level features and the output of the
        previous decode head; its return value is passed to
        ``loss_by_feat`` / ``predict_by_feat`` as the segmentation logits.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of the previous decode head.
        """

    def loss(self, inputs: List[Tensor], prev_output: Tensor,
             batch_data_samples: List[dict], train_cfg: ConfigType,
             **kwargs) -> dict:
        """Forward function for training.

        Runs :meth:`forward` on the features and the previous head's output,
        then computes the losses with ``self.loss_by_feat``.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of the previous decode head.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.
            train_cfg (dict): The training config. Not referenced by this
                base implementation.
            **kwargs: Forwarded unchanged to ``loss_by_feat``.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        seg_logits = self.forward(inputs, prev_output)
        losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
        return losses

    def predict(self, inputs: List[Tensor], prev_output: Tensor,
                batch_img_metas: List[dict], tese_cfg: ConfigType,
                **kwargs) -> Tensor:
        """Forward function for testing.

        Runs :meth:`forward` on the features and the previous head's output,
        then post-processes the logits with ``self.predict_by_feat``.

        Args:
            inputs (List[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of the previous decode head.
            batch_img_metas (List[dict]): List of image info dicts, where
                each dict may also contain: 'img_shape', 'scale_factor',
                'flip', 'img_path', 'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
            tese_cfg (dict): The testing config. Not referenced by this
                base implementation. NOTE: the name is a misspelling of
                ``test_cfg``; it is left unchanged so callers passing it
                by keyword keep working.
            **kwargs: Forwarded unchanged to ``predict_by_feat``.

        Returns:
            Tensor: Output segmentation map.
        """
        seg_logits = self.forward(inputs, prev_output)
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)