# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import pytest

from mmseg.datasets.pipelines import *  # noqa
from mmseg.registry import TRANSFORMS


def test_multi_scale_flip_aug():
    """Exercise ``MultiScaleFlipAug`` built through the ``TRANSFORMS`` registry.

    The test first checks that three invalid ``scales`` / ``scale_factor``
    combinations make the build raise ``AssertionError``, then runs the
    transform on a sample image and compares the ``scale`` and ``flip``
    values recorded on the returned data samples against the expected
    sequences, with and without flipping.
    """

    def build_with_resize_only(scales, scale_factor):
        # A fresh config is created on every call so that no dict is shared
        # between successive builds.
        cfg = dict(
            type='MultiScaleFlipAug',
            scales=scales,
            scale_factor=scale_factor,
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        return TRANSFORMS.build(cfg)

    # Argument combinations that must be rejected, in order:
    #   scales=None, scale_factor=1 (not float)
    #   scales=None, scale_factor=None
    #   scales=(512, 512), scale_factor=1 (not float)
    rejected_args = [(None, 1), (None, None), ((512, 512), 1)]
    for bad_scales, bad_factor in rejected_args:
        with pytest.raises(AssertionError):
            build_with_resize_only(bad_scales, bad_factor)

    meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
                 'scale_factor', 'scale', 'flip')

    # Input image; noted in the original test as having shape (288, 512, 3).
    image_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
    img = mmcv.imread(image_path, 'color')
    height, width = img.shape[0], img.shape[1]
    results = dict(
        img=img,
        ori_shape=img.shape,
        ori_height=height,
        ori_width=width,
        # Set initial values for default meta_keys
        pad_shape=img.shape,
        scale_factor=1.0,
    )

    def run_tta(allow_flip, **scale_kwargs):
        # Build the transform from a fresh config, apply it to a shallow copy
        # of ``results`` and hand back the produced data samples.
        cfg = dict(type='MultiScaleFlipAug')
        cfg.update(scale_kwargs)
        cfg.update(
            allow_flip=allow_flip,
            resize_cfg=dict(type='Resize', keep_ratio=False),
            transforms=[
                dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)
            ],
        )
        tta_module = TRANSFORMS.build(cfg)
        return tta_module(results.copy())['data_sample']

    def check(samples, expected_scales, expected_flips):
        # Scales are compared before flips, as in the original assertions.
        assert [sample.scale for sample in samples] == expected_scales
        assert [sample.flip for sample in samples] == expected_flips

    three_scales = ((256, 256), (512, 512), (1024, 1024))
    three_factors = (0.5, 1.0, 2.0)
    factor_scales = ((256, 144), (512, 288), (1024, 576))
    no_flips = [False, False, False]
    alternating_flips = [False, True, False, True, False, True]

    # Several explicit scales, flipping disabled.
    samples = run_tta(False, scales=list(three_scales))
    check(samples, list(three_scales), no_flips)

    # Several explicit scales, flipping enabled: every scale shows up twice,
    # unflipped first and flipped second.
    samples = run_tta(True, scales=list(three_scales))
    check(samples,
          [scale for scale in three_scales for _ in range(2)],
          alternating_flips)

    # One explicit scale, flipping disabled: only the first sample is
    # inspected.
    samples = run_tta(False, scales=[(512, 512)])
    first_sample = samples[0]
    assert [first_sample.scale] == [(512, 512)]
    assert [first_sample.flip] == [False]

    # One explicit scale, flipping enabled.
    samples = run_tta(True, scales=[(512, 512)])
    check(samples, [(512, 512), (512, 512)], [False, True])

    # Scale factors instead of explicit scales, flipping disabled.
    samples = run_tta(False, scale_factor=list(three_factors))
    check(samples, list(factor_scales), no_flips)

    # Scale factors instead of explicit scales, flipping enabled.
    samples = run_tta(True, scale_factor=list(three_factors))
    check(samples,
          [scale for scale in factor_scales for _ in range(2)],
          alternating_flips)