mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix the problem of post-processing not removing padding (#2367)
* add img_padding_size * minor change * add pad_shape to data_samples
This commit is contained in:
parent
90c816b6de
commit
925faea5bf
@ -137,12 +137,14 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
|||||||
'as the image size might be different in a batch')
|
'as the image size might be different in a batch')
|
||||||
# pad images when testing
|
# pad images when testing
|
||||||
if self.test_cfg:
|
if self.test_cfg:
|
||||||
inputs, _ = stack_batch(
|
inputs, padded_samples = stack_batch(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
size=self.test_cfg.get('size', None),
|
size=self.test_cfg.get('size', None),
|
||||||
size_divisor=self.test_cfg.get('size_divisor', None),
|
size_divisor=self.test_cfg.get('size_divisor', None),
|
||||||
pad_val=self.pad_val,
|
pad_val=self.pad_val,
|
||||||
seg_pad_val=self.seg_pad_val)
|
seg_pad_val=self.seg_pad_val)
|
||||||
|
for data_sample, pad_info in zip(data_samples, padded_samples):
|
||||||
|
data_sample.set_metainfo({**pad_info})
|
||||||
else:
|
else:
|
||||||
inputs = torch.stack(inputs, dim=0)
|
inputs = torch.stack(inputs, dim=0)
|
||||||
|
|
||||||
|
@ -159,8 +159,12 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
|||||||
if not only_prediction:
|
if not only_prediction:
|
||||||
img_meta = data_samples[i].metainfo
|
img_meta = data_samples[i].metainfo
|
||||||
# remove padding area
|
# remove padding area
|
||||||
padding_left, padding_right, padding_top, padding_bottom = \
|
if 'img_padding_size' not in img_meta:
|
||||||
img_meta.get('padding_size', [0]*4)
|
padding_size = img_meta.get('padding_size', [0] * 4)
|
||||||
|
else:
|
||||||
|
padding_size = img_meta['img_padding_size']
|
||||||
|
padding_left, padding_right, padding_top, padding_bottom =\
|
||||||
|
padding_size
|
||||||
# i_seg_logits shape is 1, C, H, W after remove padding
|
# i_seg_logits shape is 1, C, H, W after remove padding
|
||||||
i_seg_logits = seg_logits[i:i + 1, :,
|
i_seg_logits = seg_logits[i:i + 1, :,
|
||||||
padding_top:H - padding_bottom,
|
padding_top:H - padding_bottom,
|
||||||
|
@ -105,6 +105,9 @@ def stack_batch(inputs: List[torch.Tensor],
|
|||||||
})
|
})
|
||||||
padded_samples.append(data_sample)
|
padded_samples.append(data_sample)
|
||||||
else:
|
else:
|
||||||
padded_samples = None
|
padded_samples.append(
|
||||||
|
dict(
|
||||||
|
img_padding_size=padding_size,
|
||||||
|
pad_shape=pad_img.shape[-2:]))
|
||||||
|
|
||||||
return torch.stack(padded_inputs, dim=0), padded_samples
|
return torch.stack(padded_inputs, dim=0), padded_samples
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
from mmengine import ConfigDict
|
from mmengine import ConfigDict
|
||||||
|
from mmengine.structures import PixelData
|
||||||
|
|
||||||
from mmseg.models import build_segmentor
|
from mmseg.models import build_segmentor
|
||||||
|
from mmseg.structures import SegDataSample
|
||||||
from .utils import _segmentor_forward_train_test
|
from .utils import _segmentor_forward_train_test
|
||||||
|
|
||||||
|
|
||||||
@ -57,3 +59,42 @@ def test_encoder_decoder():
|
|||||||
cfg.test_cfg = ConfigDict(mode='whole')
|
cfg.test_cfg = ConfigDict(mode='whole')
|
||||||
segmentor = build_segmentor(cfg)
|
segmentor = build_segmentor(cfg)
|
||||||
_segmentor_forward_train_test(segmentor)
|
_segmentor_forward_train_test(segmentor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_postprocess_result():
|
||||||
|
cfg = ConfigDict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
backbone=dict(type='ExampleBackbone'),
|
||||||
|
decode_head=dict(type='ExampleDecodeHead'),
|
||||||
|
train_cfg=None,
|
||||||
|
test_cfg=dict(mode='whole'))
|
||||||
|
model = build_segmentor(cfg)
|
||||||
|
|
||||||
|
# test postprocess
|
||||||
|
data_sample = SegDataSample()
|
||||||
|
data_sample.gt_sem_seg = PixelData(
|
||||||
|
**{'data': torch.randint(0, 10, (1, 8, 8))})
|
||||||
|
data_sample.set_metainfo({
|
||||||
|
'padding_size': (0, 2, 0, 2),
|
||||||
|
'ori_shape': (8, 8)
|
||||||
|
})
|
||||||
|
seg_logits = torch.zeros((1, 2, 10, 10))
|
||||||
|
seg_logits[:, :, :8, :8] = 1
|
||||||
|
data_samples = [data_sample]
|
||||||
|
|
||||||
|
outputs = model.postprocess_result(seg_logits, data_samples)
|
||||||
|
assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
|
||||||
|
assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
|
||||||
|
|
||||||
|
data_sample = SegDataSample()
|
||||||
|
data_sample.gt_sem_seg = PixelData(
|
||||||
|
**{'data': torch.randint(0, 10, (1, 8, 8))})
|
||||||
|
data_sample.set_metainfo({
|
||||||
|
'img_padding_size': (0, 2, 0, 2),
|
||||||
|
'ori_shape': (8, 8)
|
||||||
|
})
|
||||||
|
|
||||||
|
data_samples = [data_sample]
|
||||||
|
outputs = model.postprocess_result(seg_logits, data_samples)
|
||||||
|
assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
|
||||||
|
assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user