# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Tuple

from mmengine.data import PixelData
from mmengine.model import BaseModel
from torch import Tensor

from mmseg.core import SegDataSample
from mmseg.core.utils import (ForwardResults, OptConfigType, OptMultiConfig,
                              OptSampleList, SampleList)
from mmseg.ops import resize


class BaseSegmentor(BaseModel, metaclass=ABCMeta):
    """Base class for segmentors.

    Args:
        data_preprocessor (dict, optional): Model preprocessing config
            for processing the input data. It usually includes
            ``to_rgb``, ``pad_size_divisor``, ``pad_val``,
            ``mean`` and ``std``. Defaults to None.
        init_cfg (dict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
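
    # NOTE: a minimal, hypothetical sketch of how a concrete subclass could be
    # instantiated with a ``data_preprocessor`` config. The keys follow the
    # class docstring above; ``SomeSegmentor`` and the numeric values
    # (ImageNet mean/std, pad settings) are illustrative assumptions, not
    # defaults of this class:
    #
    #     segmentor = SomeSegmentor(  # any BaseSegmentor subclass
    #         data_preprocessor=dict(
    #             mean=[123.675, 116.28, 103.53],
    #             std=[58.395, 57.12, 57.375],
    #             to_rgb=True,
    #             pad_val=0,
    #             pad_size_divisor=32))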

    @property
    def with_neck(self) -> bool:
        """bool: whether the segmentor has a neck."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self) -> bool:
        """bool: whether the segmentor has an auxiliary head."""
        return hasattr(self,
                       'auxiliary_head') and self.auxiliary_head is not None

    @property
    def with_decode_head(self) -> bool:
        """bool: whether the segmentor has a decode head."""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, batch_inputs: Tensor) -> List[Tensor]:
        """Placeholder for extracting features from images."""
        pass

    @abstractmethod
    def encode_decode(self, batch_inputs: Tensor,
                      batch_data_samples: SampleList):
        """Placeholder for encoding images with the backbone and decoding
        them into a semantic segmentation map of the same size as the
        input."""
        pass

    def forward(self,
                batch_inputs: Tensor,
                batch_data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return a tensor or tuple of
          tensors without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`SegDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in :meth:`train_step`.

        Args:
            batch_inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            batch_data_samples (list[:obj:`SegDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``:

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of :obj:`SegDataSample`.
            - If ``mode="loss"``, return a dict of tensors.
        """
        if mode == 'loss':
            return self.loss(batch_inputs, batch_data_samples)
        elif mode == 'predict':
            return self.predict(batch_inputs, batch_data_samples)
        elif mode == 'tensor':
            return self._forward(batch_inputs, batch_data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')
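
    # NOTE: a hedged usage sketch (not part of the original file): a concrete
    # subclass instance ``segmentor`` is typically driven through ``forward``
    # by the training and test loops roughly like this:
    #
    #     losses = segmentor(batch_inputs, batch_data_samples, mode='loss')
    #     data_samples = segmentor(
    #         batch_inputs, batch_data_samples, mode='predict')
    #     feats = segmentor(batch_inputs, mode='tensor')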

    @abstractmethod
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""
        pass

    @abstractmethod
    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        pass

    @abstractmethod
    def aug_test(self, batch_inputs, batch_img_metas):
        """Placeholder for augmentation test."""
        pass

    def postprocess_result(self, seg_logits_list: List[Tensor],
                           batch_img_metas: List[dict]) -> list:
        """Convert a list of seg logits to ``SegDataSample`` results.

        Args:
            seg_logits_list (List[Tensor]): List of segmentation results,
                seg_logits from model of each input image.
            batch_img_metas (List[dict]): Meta information of each image,
                e.g. the ``ori_shape`` used to resize the logits back to
                the original image size.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contains:

            - ``pred_sem_seg`` (PixelData): Prediction of semantic
              segmentation.
            - ``seg_logits`` (PixelData): Predicted logits of semantic
              segmentation before normalization.
        """
        predictions = []

        for i in range(len(seg_logits_list)):
            img_meta = batch_img_metas[i]
            seg_logits = resize(
                seg_logits_list[i][None],
                size=img_meta['ori_shape'],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False).squeeze(0)
            # seg_logits shape is CHW
            seg_pred = seg_logits.argmax(dim=0, keepdim=True)
            prediction = SegDataSample(**{'metainfo': img_meta})
            prediction.set_data({
                'seg_logits': PixelData(**{'data': seg_logits}),
                'pred_sem_seg': PixelData(**{'data': seg_pred})
            })
            predictions.append(prediction)
        return predictions
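

# NOTE: a hedged sketch of how a subclass might tie ``predict`` and
# ``postprocess_result`` together. ``MySegmentor`` and ``inference_logits``
# are hypothetical names standing in for the subclass's own inference logic:
#
#     class MySegmentor(BaseSegmentor):
#
#         def predict(self, batch_inputs, batch_data_samples):
#             batch_img_metas = [ds.metainfo for ds in batch_data_samples]
#             seg_logits_list = self.inference_logits(
#                 batch_inputs, batch_img_metas)
#             return self.postprocess_result(seg_logits_list, batch_img_metas)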