[Fix] Fix the problem of post-processing not removing padding (#2367)

* add img_padding_size

* minor change

* add pad_shape to data_samples
谢昕辰 2022-12-01 16:35:39 +08:00 committed by GitHub
parent 90c816b6de
commit 925faea5bf
4 changed files with 55 additions and 5 deletions

@@ -137,12 +137,14 @@ class SegDataPreProcessor(BaseDataPreprocessor):
                 'as the image size might be different in a batch')
             # pad images when testing
             if self.test_cfg:
-                inputs, _ = stack_batch(
+                inputs, padded_samples = stack_batch(
                     inputs=inputs,
                     size=self.test_cfg.get('size', None),
                     size_divisor=self.test_cfg.get('size_divisor', None),
                     pad_val=self.pad_val,
                     seg_pad_val=self.seg_pad_val)
+                for data_sample, pad_info in zip(data_samples, padded_samples):
+                    data_sample.set_metainfo({**pad_info})
             else:
                 inputs = torch.stack(inputs, dim=0)
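
With this change, test-time padding information is no longer discarded: stack_batch returns one pad-info dict per image, and each entry is merged into the corresponding sample's metainfo. A minimal sketch of that merge, with assumed values:

from mmseg.structures import SegDataSample

# hypothetical pad-info entry as returned by stack_batch in test mode
pad_info = dict(img_padding_size=[0, 2, 0, 2], pad_shape=(10, 10))
data_sample = SegDataSample()
data_sample.set_metainfo({**pad_info})
assert data_sample.metainfo['img_padding_size'] == [0, 2, 0, 2]
assert data_sample.metainfo['pad_shape'] == (10, 10)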

@@ -159,8 +159,12 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
             if not only_prediction:
                 img_meta = data_samples[i].metainfo
                 # remove padding area
-                padding_left, padding_right, padding_top, padding_bottom = \
-                    img_meta.get('padding_size', [0]*4)
+                if 'img_padding_size' not in img_meta:
+                    padding_size = img_meta.get('padding_size', [0] * 4)
+                else:
+                    padding_size = img_meta['img_padding_size']
+                padding_left, padding_right, padding_top, padding_bottom =\
+                    padding_size
                 # i_seg_logits shape is 1, C, H, W after remove padding
                 i_seg_logits = seg_logits[i:i + 1, :,
                                           padding_top:H - padding_bottom,
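
For context, the slicing these padding values feed into crops the padded logits back to the un-padded region before they are resized to ori_shape. A self-contained sketch with assumed sizes (the real statement continues with the left/right slice on the next source line):

import torch

# assumed: 8x8 valid region padded to 10x10 on the right and bottom
H, W = 10, 10
padding_left, padding_right, padding_top, padding_bottom = 0, 2, 0, 2
seg_logits = torch.zeros((1, 2, H, W))
seg_logits[:, :, :8, :8] = 1

i = 0
i_seg_logits = seg_logits[i:i + 1, :,
                          padding_top:H - padding_bottom,
                          padding_left:W - padding_right]
assert i_seg_logits.shape == (1, 2, 8, 8)
assert torch.all(i_seg_logits == 1)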

@@ -105,6 +105,9 @@ def stack_batch(inputs: List[torch.Tensor],
                 })
             padded_samples.append(data_sample)
         else:
-            padded_samples = None
+            padded_samples.append(
+                dict(
+                    img_padding_size=padding_size,
+                    pad_shape=pad_img.shape[-2:]))
 
     return torch.stack(padded_inputs, dim=0), padded_samples
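
The order of img_padding_size matches the (left, right, top, bottom) unpacking in postprocess_result above. A hedged sketch of how one such entry could be built, assuming stack_batch pads the right and bottom with F.pad (shapes and target size are made up for illustration):

import torch
import torch.nn.functional as F

tensor = torch.ones(3, 8, 8)                      # one unbatched CHW image
target_h, target_w = 10, 10                       # assumed test-time size
padding_size = [0, target_w - tensor.shape[-1],   # left, right,
                0, target_h - tensor.shape[-2]]   # top, bottom
pad_img = F.pad(tensor, padding_size, value=0)
pad_info = dict(img_padding_size=padding_size, pad_shape=pad_img.shape[-2:])
assert pad_info['pad_shape'] == (10, 10)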

@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 from mmengine import ConfigDict
+from mmengine.structures import PixelData
 
 from mmseg.models import build_segmentor
+from mmseg.structures import SegDataSample
 from .utils import _segmentor_forward_train_test
@@ -57,3 +59,42 @@
     cfg.test_cfg = ConfigDict(mode='whole')
     segmentor = build_segmentor(cfg)
     _segmentor_forward_train_test(segmentor)
+
+
+def test_postprocess_result():
+    cfg = ConfigDict(
+        type='EncoderDecoder',
+        backbone=dict(type='ExampleBackbone'),
+        decode_head=dict(type='ExampleDecodeHead'),
+        train_cfg=None,
+        test_cfg=dict(mode='whole'))
+    model = build_segmentor(cfg)
+
+    # test postprocess
+    data_sample = SegDataSample()
+    data_sample.gt_sem_seg = PixelData(
+        **{'data': torch.randint(0, 10, (1, 8, 8))})
+    data_sample.set_metainfo({
+        'padding_size': (0, 2, 0, 2),
+        'ori_shape': (8, 8)
+    })
+    seg_logits = torch.zeros((1, 2, 10, 10))
+    seg_logits[:, :, :8, :8] = 1
+    data_samples = [data_sample]
+
+    outputs = model.postprocess_result(seg_logits, data_samples)
+    assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
+    assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
+
+    data_sample = SegDataSample()
+    data_sample.gt_sem_seg = PixelData(
+        **{'data': torch.randint(0, 10, (1, 8, 8))})
+    data_sample.set_metainfo({
+        'img_padding_size': (0, 2, 0, 2),
+        'ori_shape': (8, 8)
+    })
+    data_samples = [data_sample]
+
+    outputs = model.postprocess_result(seg_logits, data_samples)
+    assert outputs[0].seg_logits.data.shape == torch.Size((2, 8, 8))
+    assert torch.allclose(outputs[0].seg_logits.data, torch.ones((2, 8, 8)))
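
The test exercises both metainfo keys; when neither key is present, the new branch falls back to [0, 0, 0, 0], so an unpadded logit map passes through unchanged. A small sketch mirroring that fallback (shapes assumed):

import torch

img_meta = {'ori_shape': (8, 8)}   # neither padding_size nor img_padding_size
if 'img_padding_size' not in img_meta:
    padding_size = img_meta.get('padding_size', [0] * 4)
else:
    padding_size = img_meta['img_padding_size']
padding_left, padding_right, padding_top, padding_bottom = padding_size

seg_logits = torch.ones((1, 2, 8, 8))
H, W = seg_logits.shape[-2:]
i_seg_logits = seg_logits[0:1, :,
                          padding_top:H - padding_bottom,
                          padding_left:W - padding_right]
assert i_seg_logits.shape == (1, 2, 8, 8)   # nothing was cropped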