mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enhance] support inference with padding (#1607)
* [Enhance] support inference with padding * require Pad to run after RandomFlip during inference * add test code
This commit is contained in:
parent
2dede04703
commit
6f43f4d5d3
@ -57,6 +57,14 @@ class MultiScaleFlipAug(object):
|
|||||||
img_ratios=None,
|
img_ratios=None,
|
||||||
flip=False,
|
flip=False,
|
||||||
flip_direction='horizontal'):
|
flip_direction='horizontal'):
|
||||||
|
if flip:
|
||||||
|
trans_index = {
|
||||||
|
key['type']: index
|
||||||
|
for index, key in enumerate(transforms)
|
||||||
|
}
|
||||||
|
if 'RandomFlip' in trans_index and 'Pad' in trans_index:
|
||||||
|
assert trans_index['RandomFlip'] < trans_index['Pad'], \
|
||||||
|
'Pad must be executed after RandomFlip when flip is True'
|
||||||
self.transforms = Compose(transforms)
|
self.transforms = Compose(transforms)
|
||||||
if img_ratios is not None:
|
if img_ratios is not None:
|
||||||
img_ratios = img_ratios if isinstance(img_ratios,
|
img_ratios = img_ratios if isinstance(img_ratios,
|
||||||
|
@ -189,6 +189,9 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
count_mat.cpu().detach().numpy()).to(device=img.device)
|
count_mat.cpu().detach().numpy()).to(device=img.device)
|
||||||
preds = preds / count_mat
|
preds = preds / count_mat
|
||||||
if rescale:
|
if rescale:
|
||||||
|
# remove padding area
|
||||||
|
resize_shape = img_meta[0]['img_shape'][:2]
|
||||||
|
preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
|
||||||
preds = resize(
|
preds = resize(
|
||||||
preds,
|
preds,
|
||||||
size=img_meta[0]['ori_shape'][:2],
|
size=img_meta[0]['ori_shape'][:2],
|
||||||
@ -206,6 +209,9 @@ class EncoderDecoder(BaseSegmentor):
|
|||||||
if torch.onnx.is_in_onnx_export():
|
if torch.onnx.is_in_onnx_export():
|
||||||
size = img.shape[2:]
|
size = img.shape[2:]
|
||||||
else:
|
else:
|
||||||
|
# remove padding area
|
||||||
|
resize_shape = img_meta[0]['img_shape'][:2]
|
||||||
|
seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]]
|
||||||
size = img_meta[0]['ori_shape'][:2]
|
size = img_meta[0]['ori_shape'][:2]
|
||||||
seg_logit = resize(
|
seg_logit = resize(
|
||||||
seg_logit,
|
seg_logit,
|
||||||
|
@ -149,3 +149,41 @@ def test_multi_scale_flip_aug():
|
|||||||
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
|
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
|
||||||
(512, 512), (1024, 1024), (1024, 1024)]
|
(512, 512), (1024, 1024), (1024, 1024)]
|
||||||
assert tta_results['flip'] == [False, True, False, True, False, True]
|
assert tta_results['flip'] == [False, True, False, True, False, True]
|
||||||
|
|
||||||
|
# test assertion if flip is True and Pad executed before RandomFlip
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
tta_transform = dict(
|
||||||
|
type='MultiScaleFlipAug',
|
||||||
|
img_scale=[(256, 256), (512, 512), (1024, 1024)],
|
||||||
|
img_ratios=None,
|
||||||
|
flip=True,
|
||||||
|
transforms=[
|
||||||
|
dict(type='Resize', keep_ratio=False),
|
||||||
|
dict(type='Pad', size_divisor=32),
|
||||||
|
dict(type='RandomFlip'),
|
||||||
|
])
|
||||||
|
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
||||||
|
|
||||||
|
tta_transform = dict(
|
||||||
|
type='MultiScaleFlipAug',
|
||||||
|
img_scale=[(256, 256), (512, 512), (1024, 1024)],
|
||||||
|
img_ratios=None,
|
||||||
|
flip=True,
|
||||||
|
transforms=[
|
||||||
|
dict(type='Resize', keep_ratio=True),
|
||||||
|
dict(type='RandomFlip'),
|
||||||
|
dict(type='Pad', size_divisor=32),
|
||||||
|
])
|
||||||
|
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
||||||
|
tta_results = tta_module(results.copy())
|
||||||
|
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
|
||||||
|
(512, 512), (1024, 1024), (1024, 1024)]
|
||||||
|
assert tta_results['flip'] == [False, True, False, True, False, True]
|
||||||
|
assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3),
|
||||||
|
(288, 512, 3), (288, 512, 3),
|
||||||
|
(576, 1024, 3), (576, 1024, 3)]
|
||||||
|
assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3),
|
||||||
|
(288, 512, 3), (288, 512, 3),
|
||||||
|
(576, 1024, 3), (576, 1024, 3)]
|
||||||
|
for i in range(len(tta_results['img'])):
|
||||||
|
assert tta_results['img'][i].shape == tta_results['pad_shape'][i]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user