# Copyright (c) OpenMMLab. All rights reserved. # import os.path as osp # import mmcv # import pytest # from mmseg.datasets.transforms import * # noqa # from mmseg.registry import TRANSFORMS # TODO # def test_multi_scale_flip_aug(): # # test assertion if scales=None, scale_factor=1 (not float). # with pytest.raises(AssertionError): # tta_transform = dict( # type='MultiScaleFlipAug', # scales=None, # scale_factor=1, # transforms=[dict(type='Resize', keep_ratio=False)], # ) # TRANSFORMS.build(tta_transform) # # test assertion if scales=None, scale_factor=None. # with pytest.raises(AssertionError): # tta_transform = dict( # type='MultiScaleFlipAug', # scales=None, # scale_factor=None, # transforms=[dict(type='Resize', keep_ratio=False)], # ) # TRANSFORMS.build(tta_transform) # # test assertion if scales=(512, 512), scale_factor=1 (not float). # with pytest.raises(AssertionError): # tta_transform = dict( # type='MultiScaleFlipAug', # scales=(512, 512), # scale_factor=1, # transforms=[dict(type='Resize', keep_ratio=False)], # ) # TRANSFORMS.build(tta_transform) # meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape', # 'scale_factor', 'scale', 'flip') # tta_transform = dict( # type='MultiScaleFlipAug', # scales=[(256, 256), (512, 512), (1024, 1024)], # allow_flip=False, # resize_cfg=dict(type='Resize', keep_ratio=False), # transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], # ) # tta_module = TRANSFORMS.build(tta_transform) # results = dict() # # (288, 512, 3) # img = mmcv.imread( # osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') # results['img'] = img # results['ori_shape'] = img.shape # results['ori_height'] = img.shape[0] # results['ori_width'] = img.shape[1] # # Set initial values for default meta_keys # results['pad_shape'] = img.shape # results['scale_factor'] = 1.0 # tta_results = tta_module(results.copy()) # assert [data_sample.scale # for data_sample in tta_results['data_sample']] == [(256, 256), # (512, 512), # (1024, 1024)] # assert [data_sample.flip for data_sample in tta_results['data_sample'] # ] == [False, False, False] # tta_transform = dict( # type='MultiScaleFlipAug', # scales=[(256, 256), (512, 512), (1024, 1024)], # allow_flip=True, # resize_cfg=dict(type='Resize', keep_ratio=False), # transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], # ) # tta_module = TRANSFORMS.build(tta_transform) # tta_results = tta_module(results.copy()) # assert [data_sample.scale # for data_sample in tta_results['data_sample']] == [(256, 256), # (256, 256), # (512, 512), # (512, 512), # (1024, 1024), # (1024, 1024)] # assert [data_sample.flip for data_sample in tta_results['data_sample'] # ] == [False, True, False, True, False, True] # tta_transform = dict( # type='MultiScaleFlipAug', # scales=[(512, 512)], # allow_flip=False, # resize_cfg=dict(type='Resize', keep_ratio=False), # transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], # ) # tta_module = TRANSFORMS.build(tta_transform) # tta_results = tta_module(results.copy()) # assert [tta_results['data_sample'][0].scale] == [(512, 512)] # assert [tta_results['data_sample'][0].flip] == [False] # tta_transform = dict( # type='MultiScaleFlipAug', # scales=[(512, 512)], # allow_flip=True, # resize_cfg=dict(type='Resize', keep_ratio=False), # transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], # ) # tta_module = TRANSFORMS.build(tta_transform) # tta_results = tta_module(results.copy()) # assert [data_sample.scale # for data_sample in tta_results['data_sample']] == [(512, 512), # (512, 512)] # assert [data_sample.flip # for data_sample in tta_results['data_sample']] == [False, True] # tta_transform = dict( # type='MultiScaleFlipAug', # scale_factor=[0.5, 1.0, 2.0], # allow_flip=False, # resize_cfg=dict(type='Resize', keep_ratio=False), # transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], # ) # tta_module = TRANSFORMS.build(tta_transform) # tta_results = tta_module(results.copy()) # assert [data_sample.scale # for data_sample in tta_results['data_sample']] == [(256, 144), # (512, 288), # (1024, 576)] # assert [data_sample.flip for data_sample in tta_results['data_sample'] # ] == [False, False, False] # tta_transform = dict( # type='MultiScaleFlipAug', # scale_factor=[0.5, 1.0, 2.0], # allow_flip=True, # resize_cfg=dict(type='Resize', keep_ratio=False), # transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], # ) # tta_module = TRANSFORMS.build(tta_transform) # tta_results = tta_module(results.copy()) # assert [data_sample.scale # for data_sample in tta_results['data_sample']] == [(256, 144), # (256, 144), # (512, 288), # (512, 288), # (1024, 576), # (1024, 576)] # assert [data_sample.flip for data_sample in tta_results['data_sample'] # ] == [False, True, False, True, False, True]