# mmsegmentation/mmseg/models/decode_heads/cascade_decode_head.py
#
# 52 lines
# 1.8 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from .decode_head import BaseDecodeHead
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Abstract base class for cascade decode heads.

    Meant to be used in :class:`CascadeEncoderDecoder`. Every entry point
    defined here takes ``prev_output`` (the output of the previous decode
    head) in addition to the image features.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: construction is delegated to the parent class.
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function; subclasses must override it."""

    def forward_train(self, inputs, prev_output, batch_data_samples,
                      train_cfg):
        """Run the head in training mode and compute its losses.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config. Accepted for interface
                compatibility; this method does not read it.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        logits = self.forward(inputs, prev_output)
        return self.losses(logits, batch_data_samples)

    def forward_test(self, inputs, prev_output, batch_img_metas, test_cfg):
        """Run the head in inference mode.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc. Not read by this
                method.
            test_cfg (dict): The testing config. Not read by this method.

        Returns:
            Tensor: Output segmentation map.
        """
        # Inference is a plain forward pass; no post-processing happens here.
        return self.forward(inputs, prev_output)