[Refactor] Support TTA (#2184)

* tta init

* use mmcv transform

* test city

* add multiscale

* fix merge

* add softmax to post process

* add ut

* add tta pipeline to other datasets

* remove softmax

* add encoder_decoder_tta ut

* add encoder_decoder_tta ut

* rename

* rename file

* rename config

* rm aug_test

* move flip to post process

* fix channel
谢昕辰 2022-12-30 13:46:52 +08:00 committed by MeowZheng
parent 20a6c58478
commit da4125587e
24 changed files with 506 additions and 147 deletions

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
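The tta_pipeline block added to each dataset config enumerates every combination of the six scale ratios and the two flip transforms, so one test image expands into 6 x 2 = 12 augmented views. A minimal sketch of exercising such a pipeline on its own is below; it mirrors the unit test further down and assumes a local test image at tests/data/color.jpg (path is an assumption, not part of the diff).

# Minimal sketch: build the TestTimeAug part of tta_pipeline directly and
# count the augmented views. Assumes an image exists at tests/data/color.jpg.
import mmcv

from mmseg.datasets.transforms import *  # noqa: register mmseg transforms
from mmseg.registry import TRANSFORMS

img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta = TRANSFORMS.build(
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ],
            [dict(type='mmseg.PackSegInputs')],
        ]))

img = mmcv.imread('tests/data/color.jpg', 'color')
results = dict(
    img=img,
    ori_shape=img.shape,
    ori_height=img.shape[0],
    ori_width=img.shape[1],
    pad_shape=img.shape,
    scale_factor=1.0)

out = tta(results)
assert len(out['inputs']) == len(img_ratios) * 2  # 12 views per image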

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -30,6 +30,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -26,6 +26,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -25,7 +25,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
dataset_train = dict(
    type=dataset_type,
    data_root=data_root,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,

@@ -11,3 +11,5 @@ log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')

@@ -2,5 +2,8 @@
from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder
from .encoder_decoder import EncoderDecoder
from .seg_tta import SegTTAModel
__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
__all__ = [
    'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel'
]

@@ -124,11 +124,6 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
        """
        pass

    @abstractmethod
    def aug_test(self, batch_inputs, batch_img_metas):
        """Placeholder for augmentation test."""
        pass

    def postprocess_result(self,
                           seg_logits: Tensor,
                           data_samples: OptSampleList = None) -> list:
@@ -170,6 +165,15 @@
                                          padding_top:H - padding_bottom,
                                          padding_left:W - padding_right]

                flip = img_meta.get('flip', None)
                if flip:
                    flip_direction = img_meta.get('flip_direction', None)
                    assert flip_direction in ['horizontal', 'vertical']
                    if flip_direction == 'horizontal':
                        i_seg_logits = i_seg_logits.flip(dims=(3, ))
                    else:
                        i_seg_logits = i_seg_logits.flip(dims=(2, ))

                # resize as original shape
                i_seg_logits = resize(
                    i_seg_logits,
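With aug_test removed, flip handling now lives in postprocess_result: the logits predicted for a flipped view are flipped back before being resized to the original shape. Seg logits are laid out as (N, C, H, W), so a horizontal flip is undone along dim 3 and a vertical flip along dim 2. A tiny sketch of the round trip (assuming the NCHW layout):

# Sketch of the un-flip above: flipping the width dim twice is a no-op.
import torch

logits = torch.randn(1, 19, 4, 4)                         # (N, C, H, W)
restored = logits.flip(dims=(3, )).flip(dims=(3, ))       # horizontal flip is its own inverse
assert torch.equal(restored, logits)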

@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
@MODELS.register_module()
class SegTTAModel(BaseTTAModel):
    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[SampleList]): List of predictions
                of all enhanced data.

        Returns:
            SampleList: Merged prediction.
        """
        predictions = []
        for data_samples in data_samples_list:
            seg_logits = data_samples[0].seg_logits.data
            logits = torch.zeros(seg_logits.shape).to(seg_logits)
            for data_sample in data_samples:
                seg_logit = data_sample.seg_logits.data
                if self.module.out_channels > 1:
                    logits += seg_logit.softmax(dim=0)
                else:
                    logits += seg_logit.sigmoid()
            logits /= len(data_samples)
            if self.module.out_channels == 1:
                seg_pred = (logits > self.module.decode_head.threshold
                            ).to(logits).squeeze(1)
            else:
                seg_pred = logits.argmax(dim=0)
            data_sample = SegDataSample(
                **{
                    'pred_sem_seg': PixelData(data=seg_pred),
                    'gt_sem_seg': data_samples[0].gt_sem_seg
                })
            predictions.append(data_sample)
        return predictions
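SegTTAModel.merge_preds averages per-view class probabilities (softmax over the channel dim for the multi-class case, sigmoid for the single-channel binary case) and only then takes the argmax or applies the threshold. A toy sketch of that averaging for two views of one image, with made-up shapes:

# Toy sketch of the averaging in merge_preds: 2 views, 3 classes, 2x2 logits.
import torch

view_a = torch.randn(3, 2, 2)  # seg_logits of the original view, (C, H, W)
view_b = torch.randn(3, 2, 2)  # seg_logits of another view, already mapped back

avg_prob = (view_a.softmax(dim=0) + view_b.softmax(dim=0)) / 2
seg_pred = avg_prob.argmax(dim=0)  # (H, W) label map, as stored in pred_sem_seg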

@@ -1,151 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
# import os.path as osp
import os.path as osp
# import mmcv
# import pytest
import mmcv
import pytest
# from mmseg.datasets.transforms import * # noqa
# from mmseg.registry import TRANSFORMS
from mmseg.datasets.transforms import * # noqa
from mmseg.registry import TRANSFORMS
# TODO
# def test_multi_scale_flip_aug():
# # test assertion if scales=None, scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
# # test assertion if scales=None, scale_factor=None.
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=None,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
def test_multi_scale_flip_aug():
    # test exception
    with pytest.raises(TypeError):
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        TRANSFORMS.build(tta_transform)
# # test assertion if scales=(512, 512), scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=(512, 512),
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
# 'scale_factor', 'scale', 'flip')
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[[
            dict(type='Resize', scale=scale, keep_ratio=False)
            for scale in [(256, 256), (512, 512), (1024, 1024)]
        ], [dict(type='mmseg.PackSegInputs')]])
    tta_module = TRANSFORMS.build(tta_transform)
# results = dict()
# # (288, 512, 3)
# img = mmcv.imread(
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
# results['img'] = img
# results['ori_shape'] = img.shape
# results['ori_height'] = img.shape[0]
# results['ori_width'] = img.shape[1]
# # Set initial values for default meta_keys
# results['pad_shape'] = img.shape
# results['scale_factor'] = 1.0
    results = dict()
    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    results['img'] = img
    results['ori_shape'] = img.shape
    results['ori_height'] = img.shape[0]
    results['ori_width'] = img.shape[1]
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 256),
# (512, 512),
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
                                                            (3, 512, 512),
                                                            (3, 1024, 1024)]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 256),
# (256, 256),
# (512, 512),
# (512, 512),
# (1024, 1024),
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, True, False, True, False, True]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale=scale, keep_ratio=False)
                for scale in [(256, 256), (512, 512), (1024, 1024)]
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='mmseg.PackSegInputs')]
        ])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results: dict = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
                                                            (3, 256, 256),
                                                            (3, 512, 512),
                                                            (3, 512, 512),
                                                            (3, 1024, 1024),
                                                            (3, 1024, 1024)]
    assert [
        data_sample.metainfo['flip']
        for data_sample in tta_results['data_samples']
    ] == [False, True, False, True, False, True]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(512, 512)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
# assert [tta_results['data_sample'][0].flip] == [False]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
                    [dict(type='mmseg.PackSegInputs')]])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(512, 512)],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(512, 512),
# (512, 512)]
# assert [data_sample.flip
# for data_sample in tta_results['data_sample']] == [False, True]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[
            [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='mmseg.PackSegInputs')]
        ])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512),
                                                            (3, 512, 512)]
    assert [
        data_sample.metainfo['flip']
        for data_sample in tta_results['data_samples']
    ] == [False, True]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scale_factor=[0.5, 1.0, 2.0],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 144),
# (512, 288),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[[
            dict(type='Resize', scale_factor=r, keep_ratio=False)
            for r in [0.5, 1.0, 2.0]
        ], [dict(type='mmseg.PackSegInputs')]])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
                                                            (3, 288, 512),
                                                            (3, 576, 1024)]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scale_factor=[0.5, 1.0, 2.0],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 144),
# (256, 144),
# (512, 288),
# (512, 288),
# (1024, 576),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, True, False, True, False, True]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in [0.5, 1.0, 2.0]
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='mmseg.PackSegInputs')]
        ])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
                                                            (3, 144, 256),
                                                            (3, 288, 512),
                                                            (3, 288, 512),
                                                            (3, 576, 1024),
                                                            (3, 576, 1024)]
    assert [
        data_sample.metainfo['flip']
        for data_sample in tta_results['data_samples']
    ] == [False, True, False, True, False, True]

@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import ConfigDict
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
from .utils import * # noqa: F401,F403
register_all_modules()
def test_encoder_decoder_tta():
    segmentor_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg)

    model: BaseTTAModel = MODELS.build(cfg)

    imgs = []
    data_samples = []
    directions = ['horizontal', 'vertical']
    for i in range(12):
        flip_direction = directions[0] if i % 3 == 0 else directions[1]
        imgs.append(torch.randn(1, 3, 10 + i, 10 + i))
        data_samples.append([
            SegDataSample(
                metainfo=dict(
                    ori_shape=(10, 10),
                    img_shape=(10 + i, 10 + i),
                    flip=(i % 2 == 0),
                    flip_direction=flip_direction),
                gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
        ])

    model.test_step(dict(inputs=imgs, data_samples=data_samples))

    # test out_channels == 1
    segmentor_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(
            type='ExampleDecodeHead',
            num_classes=2,
            out_channels=1,
            threshold=0.4),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    model.module = MODELS.build(segmentor_cfg)
    for data_sample in data_samples:
        data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10))
    model.test_step(dict(inputs=imgs, data_samples=data_samples))

@@ -66,9 +66,9 @@ class ExampleBackbone(nn.Module):
@MODELS.register_module()
class ExampleDecodeHead(BaseDecodeHead):

    def __init__(self, num_classes=19, out_channels=None):
    def __init__(self, num_classes=19, out_channels=None, **kwargs):
        super().__init__(
            3, 3, num_classes=num_classes, out_channels=out_channels)
            3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])

@@ -43,6 +43,8 @@ def parse_args():
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--tta', action='store_true', help='Test time augmentation')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
@@ -99,6 +101,11 @@ def main():
    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    if args.tta:
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
        cfg.tta_model.module = cfg.model
        cfg.model = cfg.tta_model

    # build the runner from config
    runner = Runner.from_cfg(cfg)
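With the default_runtime tta_model, the per-dataset tta_pipeline, and SegTTAModel in place, TTA evaluation is switched on by passing --tta to tools/test.py (roughly: python tools/test.py CONFIG CHECKPOINT --tta); the branch above then swaps the test pipeline for tta_pipeline and wraps the segmentor config in the TTA wrapper. A sketch of the same wiring done by hand, with placeholder paths that are not part of this diff:

# Sketch of the --tta branch above, using placeholder config/checkpoint paths.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('path/to/your_mmseg_config.py')      # placeholder
cfg.work_dir = './work_dirs/tta_demo'                      # placeholder
cfg.load_from = 'path/to/checkpoint.pth'                   # placeholder
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline    # 12-view test pipeline
cfg.tta_model.module = cfg.model                           # wrap the segmentor config
cfg.model = cfg.tta_model                                  # runner now builds SegTTAModel

runner = Runner.from_cfg(cfg)
runner.test()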