From 6f43f4d5d3dbfa8f10e54be8f440e68b5131a19d Mon Sep 17 00:00:00 2001 From: FreyWang Date: Wed, 15 Jun 2022 11:28:09 +0800 Subject: [PATCH] [Enhance] support inference with padding (#1607) * [Enhance] support inference with padding * limit pad to after flip during inference * add test code --- mmseg/datasets/pipelines/test_time_aug.py | 8 +++++ mmseg/models/segmentors/encoder_decoder.py | 6 ++++ tests/test_data/test_tta.py | 38 ++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/mmseg/datasets/pipelines/test_time_aug.py b/mmseg/datasets/pipelines/test_time_aug.py index 5c17cbbba..49640879f 100644 --- a/mmseg/datasets/pipelines/test_time_aug.py +++ b/mmseg/datasets/pipelines/test_time_aug.py @@ -57,6 +57,14 @@ class MultiScaleFlipAug(object): img_ratios=None, flip=False, flip_direction='horizontal'): + if flip: + trans_index = { + key['type']: index + for index, key in enumerate(transforms) + } + if 'RandomFlip' in trans_index and 'Pad' in trans_index: + assert trans_index['RandomFlip'] < trans_index['Pad'], \ + 'Pad must be executed after RandomFlip when flip is True' self.transforms = Compose(transforms) if img_ratios is not None: img_ratios = img_ratios if isinstance(img_ratios, diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 72467b469..d94a3739e 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -189,6 +189,9 @@ class EncoderDecoder(BaseSegmentor): count_mat.cpu().detach().numpy()).to(device=img.device) preds = preds / count_mat if rescale: + # remove padding area + resize_shape = img_meta[0]['img_shape'][:2] + preds = preds[:, :, :resize_shape[0], :resize_shape[1]] preds = resize( preds, size=img_meta[0]['ori_shape'][:2], @@ -206,6 +209,9 @@ class EncoderDecoder(BaseSegmentor): if torch.onnx.is_in_onnx_export(): size = img.shape[2:] else: + # remove padding area + resize_shape = img_meta[0]['img_shape'][:2] + seg_logit = 
seg_logit[:, :, :resize_shape[0], :resize_shape[1]] size = img_meta[0]['ori_shape'][:2] seg_logit = resize( seg_logit, diff --git a/tests/test_data/test_tta.py b/tests/test_data/test_tta.py index d61af27ae..9373e2b62 100644 --- a/tests/test_data/test_tta.py +++ b/tests/test_data/test_tta.py @@ -149,3 +149,41 @@ def test_multi_scale_flip_aug(): assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512), (512, 512), (1024, 1024), (1024, 1024)] assert tta_results['flip'] == [False, True, False, True, False, True] + + # test assertion if flip is True and Pad executed before RandomFlip + with pytest.raises(AssertionError): + tta_transform = dict( + type='MultiScaleFlipAug', + img_scale=[(256, 256), (512, 512), (1024, 1024)], + img_ratios=None, + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=False), + dict(type='Pad', size_divisor=32), + dict(type='RandomFlip'), + ]) + tta_module = build_from_cfg(tta_transform, PIPELINES) + + tta_transform = dict( + type='MultiScaleFlipAug', + img_scale=[(256, 256), (512, 512), (1024, 1024)], + img_ratios=None, + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Pad', size_divisor=32), + ]) + tta_module = build_from_cfg(tta_transform, PIPELINES) + tta_results = tta_module(results.copy()) + assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512), + (512, 512), (1024, 1024), (1024, 1024)] + assert tta_results['flip'] == [False, True, False, True, False, True] + assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3), + (288, 512, 3), (288, 512, 3), + (576, 1024, 3), (576, 1024, 3)] + assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3), + (288, 512, 3), (288, 512, 3), + (576, 1024, 3), (576, 1024, 3)] + for i in range(len(tta_results['img'])): + assert tta_results['img'][i].shape == tta_results['pad_shape'][i]