mmsegmentation/mmseg/models/data_preprocessor.py
谢昕辰 857f854b61
[Enhancement] Remove batch inference assertion (#3210)

## Motivation

The removed assertion blocked batch inference at test time; support for it has been requested repeatedly:

https://github.com/open-mmlab/mmsegmentation/issues/3181
https://github.com/open-mmlab/mmsegmentation/issues/2965
https://github.com/open-mmlab/mmsegmentation/issues/2644
https://github.com/open-mmlab/mmsegmentation/issues/1645
https://github.com/open-mmlab/mmsegmentation/issues/1444
https://github.com/open-mmlab/mmsegmentation/issues/1370
https://github.com/open-mmlab/mmsegmentation/issues/125

## Modification

Remove the batch inference assertion in `data_preprocessor.py` so that multi-image batches are accepted at test time, as sketched below.

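For reference, the guard that now remains in the test branch is shown below (copied from the file); the exact form of the removed assertion is a reconstruction, not a quote:

```python
# Hypothetical sketch of the removed guard (not the exact original code):
#     assert len(inputs) == 1, 'Batch inference is not supported.'
# After this PR, the test branch only requires one spatial size per batch:
img_size = inputs[0].shape[1:]
assert all(input_.shape[1:] == img_size for input_ in inputs), \
    'The image size in a batch should be the same.'
```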
2023-07-20 09:45:04 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Any, Dict, List, Optional, Sequence

import torch
from mmengine.model import BaseDataPreprocessor

from mmseg.registry import MODELS
from mmseg.utils import stack_batch


@MODELS.register_module()
class SegDataPreProcessor(BaseDataPreprocessor):
"""Image pre-processor for segmentation tasks.
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
1. It won't do normalization if ``mean`` is not specified.
2. It does normalization and color space conversion after stacking batch.
3. It supports batch augmentations like mixup and cutmix.
It provides the data pre-processing as follows
- Collate and move data to the target device.
- Pad inputs to the input size with defined ``pad_val``, and pad seg map
with defined ``seg_pad_val``.
- Stack inputs to batch_inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
- Do batch augmentations like Mixup and Cutmix during training.
Args:
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
R, G, B channels. Defaults to None.
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (float, optional): Padding value. Default: 0.
seg_pad_val (float, optional): Padding value of segmentation map.
Default: 255.
padding_mode (str): Type of padding. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
rgb_to_bgr (bool): whether to convert image from RGB to RGB.
Defaults to False.
batch_augments (list[dict], optional): Batch-level augmentations
test_cfg (dict, optional): The padding size config in testing, if not
specify, will use `size` and `size_divisor` params as default.
Defaults to None, only supports keys `size` or `size_divisor`.
"""

    def __init__(
        self,
        mean: Optional[Sequence[Number]] = None,
        std: Optional[Sequence[Number]] = None,
        size: Optional[tuple] = None,
        size_divisor: Optional[int] = None,
        pad_val: Number = 0,
        seg_pad_val: Number = 255,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        batch_augments: Optional[List[dict]] = None,
        test_cfg: Optional[dict] = None,
    ):
super().__init__()
self.size = size
self.size_divisor = size_divisor
self.pad_val = pad_val
self.seg_pad_val = seg_pad_val
        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr_to_rgb` and `rgb_to_bgr` cannot be set to True at the '
            'same time')
self.channel_conversion = rgb_to_bgr or bgr_to_rgb
if mean is not None:
assert std is not None, 'To enable the normalization in ' \
'preprocessing, please specify both ' \
'`mean` and `std`.'
# Enable the normalization in preprocessing.
self._enable_normalize = True
self.register_buffer('mean',
torch.tensor(mean).view(-1, 1, 1), False)
self.register_buffer('std',
torch.tensor(std).view(-1, 1, 1), False)
else:
self._enable_normalize = False
# TODO: support batch augmentations.
self.batch_augments = batch_augments
# Support different padding methods in testing
self.test_cfg = test_cfg

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            Dict: Data in the same format as the model input.
        """
data = self.cast_data(data) # type: ignore
inputs = data['inputs']
data_samples = data.get('data_samples', None)
# TODO: whether normalize should be after stack_batch
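        # Reverse the channel order (BGR <-> RGB) when images have 3 channels.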
if self.channel_conversion and inputs[0].size(0) == 3:
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
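        # Cast to float32; normalization (if enabled) and the model expect
        # floating-point tensors regardless of the input dtype.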
inputs = [_input.float() for _input in inputs]
if self._enable_normalize:
inputs = [(_input - self.mean) / self.std for _input in inputs]
        if training:
            assert data_samples is not None, (
                'During training, `data_samples` must be defined.')
inputs, data_samples = stack_batch(
inputs=inputs,
data_samples=data_samples,
size=self.size,
size_divisor=self.size_divisor,
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
if self.batch_augments is not None:
inputs, data_samples = self.batch_augments(
inputs, data_samples)
else:
img_size = inputs[0].shape[1:]
assert all(input_.shape[1:] == img_size for input_ in inputs), \
'The image size in a batch should be the same.'
# pad images when testing
if self.test_cfg:
inputs, padded_samples = stack_batch(
inputs=inputs,
size=self.test_cfg.get('size', None),
size_divisor=self.test_cfg.get('size_divisor', None),
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
for data_sample, pad_info in zip(data_samples, padded_samples):
data_sample.set_metainfo({**pad_info})
else:
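                # No test-time padding configured: the shape check above
                # guarantees one common size, so stack directly.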
inputs = torch.stack(inputs, dim=0)
return dict(inputs=inputs, data_samples=data_samples)
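
A minimal usage sketch (not part of the repository; the normalization statistics and shapes below are illustrative assumptions):

```python
# Illustrative only: run the preprocessor on a dummy two-image batch.
# With the batch inference assertion removed, same-sized images can now be
# stacked and processed together at test time.
import torch

from mmseg.models.data_preprocessor import SegDataPreProcessor

preprocessor = SegDataPreProcessor(
    mean=[123.675, 116.28, 103.53],  # ImageNet RGB mean (assumed values)
    std=[58.395, 57.12, 57.375],  # ImageNet RGB std (assumed values)
    bgr_to_rgb=True)

data = dict(inputs=[
    torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8),
    torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8),
])
out = preprocessor(data, training=False)
print(out['inputs'].shape)  # torch.Size([2, 3, 512, 512])
```

Images of different sizes would still trip the shape assertion unless a padding `test_cfg` is supplied (along with per-sample `data_samples` to record the padding metainfo).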