mmsegmentation/mmseg/models/segmentors/base.py

167 lines
6.3 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
2020-07-07 20:52:19 +08:00
from abc import ABCMeta, abstractmethod
from typing import List, Tuple
2020-07-07 20:52:19 +08:00
from mmengine.data import PixelData
from mmengine.model import BaseModel
from torch import Tensor
2020-07-07 20:52:19 +08:00
from mmseg.core import SegDataSample
from mmseg.core.utils import (ForwardResults, OptConfigType, OptMultiConfig,
OptSampleList, SampleList)
from mmseg.ops import resize
2020-07-07 20:52:19 +08:00
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
    """Base class for segmentors.

    Concrete segmentors must implement :meth:`extract_feat`,
    :meth:`encode_decode`, :meth:`loss`, :meth:`predict`, :meth:`_forward`
    and :meth:`aug_test`. :meth:`forward` only dispatches to ``loss``,
    ``predict`` or ``_forward`` depending on ``mode``.

    Note:
        :meth:`postprocess_result` reads ``self.align_corners``, which is not
        set in this class. Subclasses are expected to define it before
        calling that method.

    Args:
        data_preprocessor (dict, optional): Model preprocessing config
            for processing the input data. It usually includes
            ``to_rgb``, ``pad_size_divisor``, ``pad_val``,
            ``mean`` and ``std``. Default to None.
        init_cfg (dict, optional): The config to control the
            initialization. Default to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: whether the segmentor has neck"""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self) -> bool:
        """bool: whether the segmentor has auxiliary head"""
        return hasattr(self,
                       'auxiliary_head') and self.auxiliary_head is not None

    @property
    def with_decode_head(self) -> bool:
        """bool: whether the segmentor has decode head"""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, batch_inputs: Tensor) -> List[Tensor]:
        """Placeholder for extract features from images."""
        pass

    @abstractmethod
    def encode_decode(self, batch_inputs: Tensor,
                      batch_data_samples: SampleList):
        """Placeholder for encode images with backbone and decode into a
        semantic segmentation map of the same size as input."""
        pass

    def forward(self,
                batch_inputs: Tensor,
                batch_data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`SegDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            batch_inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            batch_data_samples (list[:obj:`SegDataSample`], optional): The
                annotation data of every samples. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`SegDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of 'loss', 'predict' or
                'tensor'.
        """
        if mode == 'loss':
            return self.loss(batch_inputs, batch_data_samples)
        elif mode == 'predict':
            return self.predict(batch_inputs, batch_data_samples)
        elif mode == 'tensor':
            return self._forward(batch_inputs, batch_data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""
        pass

    @abstractmethod
    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        pass

    @abstractmethod
    def aug_test(self, batch_inputs, batch_img_metas):
        """Placeholder for augmentation test."""
        pass

    def postprocess_result(self, seg_logits_list: List[Tensor],
                           batch_img_metas: List[dict]) -> SampleList:
        """Convert results list to `SegDataSample`.

        Each logits tensor is resized (bilinear) to the ``ori_shape`` stored
        in the matching image meta, then reduced with ``argmax`` over the
        channel dimension to obtain the predicted label map.

        Args:
            seg_logits_list (List[Tensor]): List of segmentation results,
                seg_logits from model of each input image. Each item is
                treated as a CHW tensor.
            batch_img_metas (List[dict]): Meta information of each input
                image, index-aligned with ``seg_logits_list``. Every dict
                must provide the key ``'ori_shape'``; the whole dict is
                stored as the ``metainfo`` of the returned sample.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contain:

            - ``pred_sem_seg``(PixelData): Prediction of semantic
              segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
              segmentation before normalization.
        """
        predictions = []
        for i, raw_logits in enumerate(seg_logits_list):
            # Index (rather than zip) so a too-short meta list still raises
            # IndexError instead of silently dropping predictions.
            img_meta = batch_img_metas[i]
            # ``resize`` works on batched input: add a leading batch
            # dimension with ``[None]`` and strip it again afterwards.
            seg_logits = resize(
                raw_logits[None],
                size=img_meta['ori_shape'],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False).squeeze(0)
            # seg_logits shape is CHW; argmax over C keeps a 1xHxW label map
            seg_pred = seg_logits.argmax(dim=0, keepdim=True)
            prediction = SegDataSample(metainfo=img_meta)
            prediction.set_data({
                'seg_logits': PixelData(data=seg_logits),
                'pred_sem_seg': PixelData(data=seg_pred)
            })
            predictions.append(prediction)
        return predictions