# Copyright (c) OpenMMLab. All rights reserved.
|
|
# import os.path as osp
|
|
|
|
# import mmcv
|
|
# import pytest
|
|
|
|
# from mmseg.datasets.transforms import * # noqa
|
|
# from mmseg.registry import TRANSFORMS
|
|
|
|
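# TODO: test_multi_scale_flip_aug below is entirely commented out, so it never
# runs; re-enable it, updating it to the current MultiScaleFlipAug API if needed.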
|
|
# def test_multi_scale_flip_aug():
|
|
# # test assertion if scales=None, scale_factor=1 (not float).
|
|
# with pytest.raises(AssertionError):
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=None,
|
|
# scale_factor=1,
|
|
# transforms=[dict(type='Resize', keep_ratio=False)],
|
|
# )
|
|
# TRANSFORMS.build(tta_transform)
|
|
|
|
# # test assertion if scales=None, scale_factor=None.
|
|
# with pytest.raises(AssertionError):
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=None,
|
|
# scale_factor=None,
|
|
# transforms=[dict(type='Resize', keep_ratio=False)],
|
|
# )
|
|
# TRANSFORMS.build(tta_transform)
|
|
|
|
# # test assertion if scales=(512, 512), scale_factor=1 (not float).
|
|
# with pytest.raises(AssertionError):
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=(512, 512),
|
|
# scale_factor=1,
|
|
# transforms=[dict(type='Resize', keep_ratio=False)],
|
|
# )
|
|
# TRANSFORMS.build(tta_transform)
|
|
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
|
|
# 'scale_factor', 'scale', 'flip')
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=[(256, 256), (512, 512), (1024, 1024)],
|
|
# allow_flip=False,
|
|
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
# )
|
|
# tta_module = TRANSFORMS.build(tta_transform)
|
|
|
|
# results = dict()
|
|
# # (288, 512, 3)
|
|
# img = mmcv.imread(
|
|
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
|
# results['img'] = img
|
|
# results['ori_shape'] = img.shape
|
|
# results['ori_height'] = img.shape[0]
|
|
# results['ori_width'] = img.shape[1]
|
|
# # Set initial values for default meta_keys
|
|
# results['pad_shape'] = img.shape
|
|
# results['scale_factor'] = 1.0
|
|
|
|
# tta_results = tta_module(results.copy())
|
|
# assert [data_sample.scale
|
|
# for data_sample in tta_results['data_sample']] == [(256, 256),
|
|
# (512, 512),
|
|
# (1024, 1024)]
|
|
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
|
# ] == [False, False, False]
|
|
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=[(256, 256), (512, 512), (1024, 1024)],
|
|
# allow_flip=True,
|
|
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
# )
|
|
# tta_module = TRANSFORMS.build(tta_transform)
|
|
# tta_results = tta_module(results.copy())
|
|
# assert [data_sample.scale
|
|
# for data_sample in tta_results['data_sample']] == [(256, 256),
|
|
# (256, 256),
|
|
# (512, 512),
|
|
# (512, 512),
|
|
# (1024, 1024),
|
|
# (1024, 1024)]
|
|
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
|
# ] == [False, True, False, True, False, True]
|
|
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=[(512, 512)],
|
|
# allow_flip=False,
|
|
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
# )
|
|
# tta_module = TRANSFORMS.build(tta_transform)
|
|
# tta_results = tta_module(results.copy())
|
|
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
|
|
# assert [tta_results['data_sample'][0].flip] == [False]
|
|
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scales=[(512, 512)],
|
|
# allow_flip=True,
|
|
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
# )
|
|
# tta_module = TRANSFORMS.build(tta_transform)
|
|
# tta_results = tta_module(results.copy())
|
|
# assert [data_sample.scale
|
|
# for data_sample in tta_results['data_sample']] == [(512, 512),
|
|
# (512, 512)]
|
|
# assert [data_sample.flip
|
|
# for data_sample in tta_results['data_sample']] == [False, True]
|
|
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scale_factor=[0.5, 1.0, 2.0],
|
|
# allow_flip=False,
|
|
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
# )
|
|
# tta_module = TRANSFORMS.build(tta_transform)
|
|
# tta_results = tta_module(results.copy())
|
|
# assert [data_sample.scale
|
|
# for data_sample in tta_results['data_sample']] == [(256, 144),
|
|
# (512, 288),
|
|
# (1024, 576)]
|
|
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
|
# ] == [False, False, False]
|
|
|
|
# tta_transform = dict(
|
|
# type='MultiScaleFlipAug',
|
|
# scale_factor=[0.5, 1.0, 2.0],
|
|
# allow_flip=True,
|
|
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
# )
|
|
# tta_module = TRANSFORMS.build(tta_transform)
|
|
# tta_results = tta_module(results.copy())
|
|
# assert [data_sample.scale
|
|
# for data_sample in tta_results['data_sample']] == [(256, 144),
|
|
# (256, 144),
|
|
# (512, 288),
|
|
# (512, 288),
|
|
# (1024, 576),
|
|
# (1024, 576)]
|
|
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
|
# ] == [False, True, False, True, False, True]
|