mirror of https://github.com/open-mmlab/mmsegmentation.git

Merge pull request from xiexinch/fix_gt_padding

[Enhancement] Support padding in testing
Miao Zheng 2022-11-19 18:24:59 +08:00 committed by GitHub
commit c56a299571
3 changed files with 48 additions and 13 deletions
mmseg/models
tests/test_models
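
For reference, a minimal sketch of how the new `test_cfg` option added by this change could be exercised directly. Only the `test_cfg` keys `size`/`size_divisor` and the other constructor arguments come from the diff below; the import path, normalization values, and the expected output shape are illustrative assumptions.

import torch

from mmseg.models import SegDataPreProcessor

# Pad test-time images to a multiple of 32 instead of leaving them unpadded.
preprocessor = SegDataPreProcessor(
    mean=[123.675, 116.28, 103.53],   # illustrative ImageNet statistics
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    test_cfg=dict(size_divisor=32),   # new argument introduced by this change
)

data = dict(
    inputs=[torch.randint(0, 256, (3, 300, 500), dtype=torch.uint8)],
    data_samples=None,  # real inference would pass SegDataSample instances
)
out = preprocessor(data, training=False)
# Both spatial dims should now be multiples of 32, e.g. (1, 3, 320, 512).
print(out['inputs'].shape)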

@@ -48,9 +48,13 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         rgb_to_bgr (bool): whether to convert image from RGB to BGR.
             Defaults to False.
         batch_augments (list[dict], optional): Batch-level augmentations
+        test_cfg (dict, optional): The padding size config in testing. If not
+            specified, ``size`` and ``size_divisor`` are used as defaults.
+            Defaults to None; only keys ``size`` and ``size_divisor`` are supported.
     """

-    def __init__(self,
-                 mean: Sequence[Number] = None,
-                 std: Sequence[Number] = None,
-                 size: Optional[tuple] = None,
+    def __init__(
+        self,
+        mean: Sequence[Number] = None,
+        std: Sequence[Number] = None,
+        size: Optional[tuple] = None,
@@ -59,7 +63,9 @@ class SegDataPreProcessor(BaseDataPreprocessor):
-                 seg_pad_val: Number = 255,
-                 bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False,
-                 batch_augments: Optional[List[dict]] = None):
+        seg_pad_val: Number = 255,
+        bgr_to_rgb: bool = False,
+        rgb_to_bgr: bool = False,
+        batch_augments: Optional[List[dict]] = None,
+        test_cfg: dict = None,
+    ):
         super().__init__()
         self.size = size
         self.size_divisor = size_divisor
@@ -86,6 +92,9 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         # TODO: support batch augmentations.
         self.batch_augments = batch_augments

+        # Support different padding methods in testing
+        self.test_cfg = test_cfg
+
     def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
         """Perform normalization, padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.
@@ -122,10 +131,19 @@ class SegDataPreProcessor(BaseDataPreprocessor):
             if self.batch_augments is not None:
                 inputs, data_samples = self.batch_augments(
                     inputs, data_samples)
+            return dict(inputs=inputs, data_samples=data_samples)
         else:
             assert len(inputs) == 1, (
                 'Batch inference is not support currently, '
                 'as the image size might be different in a batch')
-
-        return dict(
-            inputs=torch.stack(inputs, dim=0), data_samples=data_samples)
+            # pad images when testing
+            if self.test_cfg:
+                inputs, _ = stack_batch(
+                    inputs=inputs,
+                    size=self.test_cfg.get('size', None),
+                    size_divisor=self.test_cfg.get('size_divisor', None),
+                    pad_val=self.pad_val,
+                    seg_pad_val=self.seg_pad_val)
+            else:
+                inputs = torch.stack(inputs, dim=0)
+            return dict(inputs=inputs, data_samples=data_samples)
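
stack_batch itself is not shown in this diff; as a rough sketch of the divisor-style padding the new test branch relies on (the helper name pad_to_divisor and the bottom/right padding placement are assumptions, not code from this commit):

import math

import torch
import torch.nn.functional as F

def pad_to_divisor(img: torch.Tensor, size_divisor: int, pad_val: float = 0) -> torch.Tensor:
    # Pad a (C, H, W) tensor on the bottom/right so that H and W become
    # multiples of size_divisor, which is what the test_cfg branch asks
    # stack_batch to do for a single test image.
    h, w = img.shape[-2:]
    target_h = math.ceil(h / size_divisor) * size_divisor
    target_w = math.ceil(w / size_divisor) * size_divisor
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(img, (0, target_w - w, 0, target_h - h), value=pad_val)

padded = pad_to_divisor(torch.rand(3, 11, 10), size_divisor=15)
print(padded.shape)  # torch.Size([3, 15, 15])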

@@ -165,6 +165,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
                 i_seg_logits = seg_logits[i:i + 1, :,
                                           padding_top:H - padding_bottom,
                                           padding_left:W - padding_right]
+
                 # resize as original shape
                 i_seg_logits = resize(
                     i_seg_logits,
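
The context above (only a blank line is added here) is where test-time padding gets undone in postprocessing: the padded border is cropped off each logit map before it is resized back to the original image shape. A minimal standalone sketch of that crop-and-resize step, with padding offsets and shapes hard-coded instead of read from the data sample's metainfo, and plain bilinear interpolation standing in for mmseg's resize helper:

import torch
import torch.nn.functional as F

seg_logits = torch.rand(1, 19, 512, 512)   # padded prediction (N, C, H, W)
padding_left, padding_right = 0, 12        # assumed padding offsets
padding_top, padding_bottom = 0, 212
ori_shape = (300, 500)                     # original image size

H, W = seg_logits.shape[2:]
# Crop away the padded border, mirroring the slicing in the context above.
i_seg_logits = seg_logits[:, :,
                          padding_top:H - padding_bottom,
                          padding_left:W - padding_right]
# Resize back to the original resolution (the real code uses mmseg's resize()).
i_seg_logits = F.interpolate(
    i_seg_logits, size=ori_shape, mode='bilinear', align_corners=False)
print(i_seg_logits.shape)  # torch.Size([1, 19, 300, 500])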

@@ -46,3 +46,19 @@ class TestSegDataPreProcessor(TestCase):
         out = processor(data, training=True)
         self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
         self.assertEqual(len(out['data_samples']), 2)
+
+        # test predict with padding
+        processor = SegDataPreProcessor(
+            mean=[0, 0, 0],
+            std=[1, 1, 1],
+            size=(20, 20),
+            test_cfg=dict(size_divisor=15))
+        data = {
+            'inputs': [
+                torch.randint(0, 256, (3, 11, 10)),
+            ],
+            'data_samples': [data_sample]
+        }
+        out = processor(data, training=False)
+        self.assertEqual(out['inputs'].shape[2] % 15, 0)
+        self.assertEqual(out['inputs'].shape[3] % 15, 0)
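
For reference, the shape these new assertions expect can be worked out directly: at test time the test_cfg divisor takes precedence over the constructor's size=(20, 20), so the 11x10 input is padded up to the next multiples of 15 (assuming divisor padding rounds each spatial dimension up):

import math

h, w = 11, 10                                  # spatial size of the test input above
divisor = 15                                   # from test_cfg=dict(size_divisor=15)
padded_h = math.ceil(h / divisor) * divisor    # 15
padded_w = math.ceil(w / divisor) * divisor    # 15
assert padded_h % 15 == 0 and padded_w % 15 == 0   # matches the assertEqual checks
print(padded_h, padded_w)                      # 15 15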