mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge pull request #2290 from xiexinch/fix_gt_padding
[Enhancement] Support padding in testing
This commit is contained in:
commit
c56a299571
@ -48,9 +48,13 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
|||||||
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
|
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
batch_augments (list[dict], optional): Batch-level augmentations
|
batch_augments (list[dict], optional): Batch-level augmentations
|
||||||
|
test_cfg (dict, optional): The padding size config in testing. If not
|
||||||
|
specified, the `size` and `size_divisor` params are used as default.
|
||||||
|
Defaults to None. Only the keys `size` and `size_divisor` are supported.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
mean: Sequence[Number] = None,
|
mean: Sequence[Number] = None,
|
||||||
std: Sequence[Number] = None,
|
std: Sequence[Number] = None,
|
||||||
size: Optional[tuple] = None,
|
size: Optional[tuple] = None,
|
||||||
@ -59,7 +63,9 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
|||||||
seg_pad_val: Number = 255,
|
seg_pad_val: Number = 255,
|
||||||
bgr_to_rgb: bool = False,
|
bgr_to_rgb: bool = False,
|
||||||
rgb_to_bgr: bool = False,
|
rgb_to_bgr: bool = False,
|
||||||
batch_augments: Optional[List[dict]] = None):
|
batch_augments: Optional[List[dict]] = None,
|
||||||
|
test_cfg: dict = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.size = size
|
self.size = size
|
||||||
self.size_divisor = size_divisor
|
self.size_divisor = size_divisor
|
||||||
@ -86,6 +92,9 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
|||||||
# TODO: support batch augmentations.
|
# TODO: support batch augmentations.
|
||||||
self.batch_augments = batch_augments
|
self.batch_augments = batch_augments
|
||||||
|
|
||||||
|
# Support different padding methods in testing
|
||||||
|
self.test_cfg = test_cfg
|
||||||
|
|
||||||
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
|
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
|
||||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||||
``BaseDataPreprocessor``.
|
``BaseDataPreprocessor``.
|
||||||
@ -122,10 +131,19 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
|||||||
if self.batch_augments is not None:
|
if self.batch_augments is not None:
|
||||||
inputs, data_samples = self.batch_augments(
|
inputs, data_samples = self.batch_augments(
|
||||||
inputs, data_samples)
|
inputs, data_samples)
|
||||||
return dict(inputs=inputs, data_samples=data_samples)
|
|
||||||
else:
|
else:
|
||||||
assert len(inputs) == 1, (
|
assert len(inputs) == 1, (
|
||||||
'Batch inference is not support currently, '
|
'Batch inference is not support currently, '
|
||||||
'as the image size might be different in a batch')
|
'as the image size might be different in a batch')
|
||||||
return dict(
|
# pad images when testing
|
||||||
inputs=torch.stack(inputs, dim=0), data_samples=data_samples)
|
if self.test_cfg:
|
||||||
|
inputs, _ = stack_batch(
|
||||||
|
inputs=inputs,
|
||||||
|
size=self.test_cfg.get('size', None),
|
||||||
|
size_divisor=self.test_cfg.get('size_divisor', None),
|
||||||
|
pad_val=self.pad_val,
|
||||||
|
seg_pad_val=self.seg_pad_val)
|
||||||
|
else:
|
||||||
|
inputs = torch.stack(inputs, dim=0)
|
||||||
|
|
||||||
|
return dict(inputs=inputs, data_samples=data_samples)
|
||||||
|
@ -165,6 +165,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
|||||||
i_seg_logits = seg_logits[i:i + 1, :,
|
i_seg_logits = seg_logits[i:i + 1, :,
|
||||||
padding_top:H - padding_bottom,
|
padding_top:H - padding_bottom,
|
||||||
padding_left:W - padding_right]
|
padding_left:W - padding_right]
|
||||||
|
|
||||||
# resize as original shape
|
# resize as original shape
|
||||||
i_seg_logits = resize(
|
i_seg_logits = resize(
|
||||||
i_seg_logits,
|
i_seg_logits,
|
||||||
|
@ -46,3 +46,19 @@ class TestSegDataPreProcessor(TestCase):
|
|||||||
out = processor(data, training=True)
|
out = processor(data, training=True)
|
||||||
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
|
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
|
||||||
self.assertEqual(len(out['data_samples']), 2)
|
self.assertEqual(len(out['data_samples']), 2)
|
||||||
|
|
||||||
|
# test predict with padding
|
||||||
|
processor = SegDataPreProcessor(
|
||||||
|
mean=[0, 0, 0],
|
||||||
|
std=[1, 1, 1],
|
||||||
|
size=(20, 20),
|
||||||
|
test_cfg=dict(size_divisor=15))
|
||||||
|
data = {
|
||||||
|
'inputs': [
|
||||||
|
torch.randint(0, 256, (3, 11, 10)),
|
||||||
|
],
|
||||||
|
'data_samples': [data_sample]
|
||||||
|
}
|
||||||
|
out = processor(data, training=False)
|
||||||
|
self.assertEqual(out['inputs'].shape[2] % 15, 0)
|
||||||
|
self.assertEqual(out['inputs'].shape[3] % 15, 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user