[Fix] Fix the problem of post-processing not removing padding (#2367)
* add img_padding_size * minor change * add pad_shape to data_samplespull/2229/head
parent
90c816b6de
commit
925faea5bf
|
@ -137,12 +137,14 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
|||
'as the image size might be different in a batch')
|
||||
# pad images when testing
|
||||
if self.test_cfg:
|
||||
inputs, _ = stack_batch(
|
||||
inputs, padded_samples = stack_batch(
|
||||
inputs=inputs,
|
||||
size=self.test_cfg.get('size', None),
|
||||
size_divisor=self.test_cfg.get('size_divisor', None),
|
||||
pad_val=self.pad_val,
|
||||
seg_pad_val=self.seg_pad_val)
|
||||
for data_sample, pad_info in zip(data_samples, padded_samples):
|
||||
data_sample.set_metainfo({**pad_info})
|
||||
else:
|
||||
inputs = torch.stack(inputs, dim=0)
|
||||
|
||||
|
|
|
@ -159,8 +159,12 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
|||
if not only_prediction:
|
||||
img_meta = data_samples[i].metainfo
|
||||
# remove padding area
|
||||
padding_left, padding_right, padding_top, padding_bottom = \
|
||||
img_meta.get('padding_size', [0]*4)
|
||||
if 'img_padding_size' not in img_meta:
|
||||
padding_size = img_meta.get('padding_size', [0] * 4)
|
||||
else:
|
||||
padding_size = img_meta['img_padding_size']
|
||||
padding_left, padding_right, padding_top, padding_bottom =\
|
||||
padding_size
|
||||
# i_seg_logits shape is 1, C, H, W after remove padding
|
||||
i_seg_logits = seg_logits[i:i + 1, :,
|
||||
padding_top:H - padding_bottom,
|
||||
|
|
|
@ -105,6 +105,9 @@ def stack_batch(inputs: List[torch.Tensor],
|
|||
})
|
||||
padded_samples.append(data_sample)
|
||||
else:
|
||||
padded_samples = None
|
||||
padded_samples.append(
|
||||
dict(
|
||||
img_padding_size=padding_size,
|
||||
pad_shape=pad_img.shape[-2:]))
|
||||
|
||||
return torch.stack(padded_inputs, dim=0), padded_samples
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.structures import SegDataSample
|
||||
from .utils import _segmentor_forward_train_test
|
||||
|
||||
|
||||
|
@ -57,3 +59,42 @@ def test_encoder_decoder():
|
|||
cfg.test_cfg = ConfigDict(mode='whole')
|
||||
segmentor = build_segmentor(cfg)
|
||||
_segmentor_forward_train_test(segmentor)
|
||||
|
||||
|
||||
def test_postprocess_result():
|
||||
cfg = ConfigDict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(type='ExampleBackbone'),
|
||||
decode_head=dict(type='ExampleDecodeHead'),
|
||||
train_cfg=None,
|
||||
test_cfg=dict(mode='whole'))
|
||||
model = build_segmentor(cfg)
|
||||
|
||||
# test postprocess
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
**{'data': torch.randint(0, 10, (1, 8, 8))})
|
||||
data_sample.set_metainfo({
|
||||
'padding_size': (0, 2, 0, 2),
|
||||
'ori_shape': (8, 8)
|
||||
})
|
||||
seg_logits = torch.zeros((1, 2, 10, 10))
|
||||
seg_logits[:, :, :8, :8] = 1
|
||||
data_samples = [data_sample]
|
||||
|
||||
outputs = model.postprocess_result(seg_logits, data_samples)
|
||||
assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
|
||||
assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
|
||||
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
**{'data': torch.randint(0, 10, (1, 8, 8))})
|
||||
data_sample.set_metainfo({
|
||||
'img_padding_size': (0, 2, 0, 2),
|
||||
'ori_shape': (8, 8)
|
||||
})
|
||||
|
||||
data_samples = [data_sample]
|
||||
outputs = model.postprocess_result(seg_logits, data_samples)
|
||||
assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
|
||||
assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
|
||||
|
|
Loading…
Reference in New Issue