mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enchance] support infererence with padding (#1607)
* [Enchance] support infererence with padding * limite pad after flip when inference * add test code
This commit is contained in:
parent
2dede04703
commit
6f43f4d5d3
@ -57,6 +57,14 @@ class MultiScaleFlipAug(object):
|
||||
img_ratios=None,
|
||||
flip=False,
|
||||
flip_direction='horizontal'):
|
||||
if flip:
|
||||
trans_index = {
|
||||
key['type']: index
|
||||
for index, key in enumerate(transforms)
|
||||
}
|
||||
if 'RandomFlip' in trans_index and 'Pad' in trans_index:
|
||||
assert trans_index['RandomFlip'] < trans_index['Pad'], \
|
||||
'Pad must be executed after RandomFlip when flip is True'
|
||||
self.transforms = Compose(transforms)
|
||||
if img_ratios is not None:
|
||||
img_ratios = img_ratios if isinstance(img_ratios,
|
||||
|
@ -189,6 +189,9 @@ class EncoderDecoder(BaseSegmentor):
|
||||
count_mat.cpu().detach().numpy()).to(device=img.device)
|
||||
preds = preds / count_mat
|
||||
if rescale:
|
||||
# remove padding area
|
||||
resize_shape = img_meta[0]['img_shape'][:2]
|
||||
preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
|
||||
preds = resize(
|
||||
preds,
|
||||
size=img_meta[0]['ori_shape'][:2],
|
||||
@ -206,6 +209,9 @@ class EncoderDecoder(BaseSegmentor):
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
size = img.shape[2:]
|
||||
else:
|
||||
# remove padding area
|
||||
resize_shape = img_meta[0]['img_shape'][:2]
|
||||
seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]]
|
||||
size = img_meta[0]['ori_shape'][:2]
|
||||
seg_logit = resize(
|
||||
seg_logit,
|
||||
|
@ -149,3 +149,41 @@ def test_multi_scale_flip_aug():
|
||||
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
|
||||
(512, 512), (1024, 1024), (1024, 1024)]
|
||||
assert tta_results['flip'] == [False, True, False, True, False, True]
|
||||
|
||||
# test assertion if flip is True and Pad executed before RandomFlip
|
||||
with pytest.raises(AssertionError):
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=[(256, 256), (512, 512), (1024, 1024)],
|
||||
img_ratios=None,
|
||||
flip=True,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=False),
|
||||
dict(type='Pad', size_divisor=32),
|
||||
dict(type='RandomFlip'),
|
||||
])
|
||||
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
||||
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=[(256, 256), (512, 512), (1024, 1024)],
|
||||
img_ratios=None,
|
||||
flip=True,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Pad', size_divisor=32),
|
||||
])
|
||||
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
|
||||
(512, 512), (1024, 1024), (1024, 1024)]
|
||||
assert tta_results['flip'] == [False, True, False, True, False, True]
|
||||
assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3),
|
||||
(288, 512, 3), (288, 512, 3),
|
||||
(576, 1024, 3), (576, 1024, 3)]
|
||||
assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3),
|
||||
(288, 512, 3), (288, 512, 3),
|
||||
(576, 1024, 3), (576, 1024, 3)]
|
||||
for i in range(len(tta_results['img'])):
|
||||
assert tta_results['img'][i].shape == tta_results['pad_shape'][i]
|
||||
|
Loading…
x
Reference in New Issue
Block a user