[Enchance] support infererence with padding (#1607)

* [Enchance] support infererence with padding

* limite pad after flip when inference

* add test code
This commit is contained in:
FreyWang 2022-06-15 11:28:09 +08:00 committed by GitHub
parent 2dede04703
commit 6f43f4d5d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 0 deletions

View File

@ -57,6 +57,14 @@ class MultiScaleFlipAug(object):
img_ratios=None,
flip=False,
flip_direction='horizontal'):
if flip:
trans_index = {
key['type']: index
for index, key in enumerate(transforms)
}
if 'RandomFlip' in trans_index and 'Pad' in trans_index:
assert trans_index['RandomFlip'] < trans_index['Pad'], \
'Pad must be executed after RandomFlip when flip is True'
self.transforms = Compose(transforms)
if img_ratios is not None:
img_ratios = img_ratios if isinstance(img_ratios,

View File

@ -189,6 +189,9 @@ class EncoderDecoder(BaseSegmentor):
count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
if rescale:
# remove padding area
resize_shape = img_meta[0]['img_shape'][:2]
preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
preds = resize(
preds,
size=img_meta[0]['ori_shape'][:2],
@ -206,6 +209,9 @@ class EncoderDecoder(BaseSegmentor):
if torch.onnx.is_in_onnx_export():
size = img.shape[2:]
else:
# remove padding area
resize_shape = img_meta[0]['img_shape'][:2]
seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]]
size = img_meta[0]['ori_shape'][:2]
seg_logit = resize(
seg_logit,

View File

@ -149,3 +149,41 @@ def test_multi_scale_flip_aug():
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
(512, 512), (1024, 1024), (1024, 1024)]
assert tta_results['flip'] == [False, True, False, True, False, True]
# test assertion if flip is True and Pad executed before RandomFlip
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=[(256, 256), (512, 512), (1024, 1024)],
img_ratios=None,
flip=True,
transforms=[
dict(type='Resize', keep_ratio=False),
dict(type='Pad', size_divisor=32),
dict(type='RandomFlip'),
])
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=[(256, 256), (512, 512), (1024, 1024)],
img_ratios=None,
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Pad', size_divisor=32),
])
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
(512, 512), (1024, 1024), (1024, 1024)]
assert tta_results['flip'] == [False, True, False, True, False, True]
assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3),
(288, 512, 3), (288, 512, 3),
(576, 1024, 3), (576, 1024, 3)]
assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3),
(288, 512, 3), (288, 512, 3),
(576, 1024, 3), (576, 1024, 3)]
for i in range(len(tta_results['img'])):
assert tta_results['img'][i].shape == tta_results['pad_shape'][i]