[Refactor] Support TTA (#2184)

* tta init

* use mmcv transform

* test city

* add multiscale

* fix merge

* add softmax to post process

* add ut

* add tta pipeline to other datasets

* remove softmax

* add encoder_decoder_tta ut

* add encoder_decoder_tta ut

* rename

* rename file

* rename config

* rm aug_test

* move flip to post process

* fix channel
This commit is contained in:
谢昕辰 2022-12-30 13:46:52 +08:00 committed by MeowZheng
parent 20a6c58478
commit da4125587e
24 changed files with 506 additions and 147 deletions

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -24,6 +24,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=2, batch_size=2,
num_workers=2, num_workers=2,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -24,6 +24,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -24,6 +24,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -30,6 +30,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -26,6 +26,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -25,7 +25,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
dataset_train = dict( dataset_train = dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -24,6 +24,22 @@ test_pipeline = [
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -23,6 +23,22 @@ test_pipeline = [
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs') dict(type='PackSegInputs')
] ]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict( train_dataloader = dict(
batch_size=4, batch_size=4,
num_workers=4, num_workers=4,

View File

@ -11,3 +11,5 @@ log_processor = dict(by_epoch=False)
log_level = 'INFO' log_level = 'INFO'
load_from = None load_from = None
resume = False resume = False
tta_model = dict(type='SegTTAModel')

View File

@ -2,5 +2,8 @@
from .base import BaseSegmentor from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder from .cascade_encoder_decoder import CascadeEncoderDecoder
from .encoder_decoder import EncoderDecoder from .encoder_decoder import EncoderDecoder
from .seg_tta import SegTTAModel
__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] __all__ = [
'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel'
]

View File

@ -124,11 +124,6 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
""" """
pass pass
@abstractmethod
def aug_test(self, batch_inputs, batch_img_metas):
"""Placeholder for augmentation test."""
pass
def postprocess_result(self, def postprocess_result(self,
seg_logits: Tensor, seg_logits: Tensor,
data_samples: OptSampleList = None) -> list: data_samples: OptSampleList = None) -> list:
@ -170,6 +165,15 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
padding_top:H - padding_bottom, padding_top:H - padding_bottom,
padding_left:W - padding_right] padding_left:W - padding_right]
flip = img_meta.get('flip', None)
if flip:
flip_direction = img_meta.get('flip_direction', None)
assert flip_direction in ['horizontal', 'vertical']
if flip_direction == 'horizontal':
i_seg_logits = i_seg_logits.flip(dims=(3, ))
else:
i_seg_logits = i_seg_logits.flip(dims=(2, ))
# resize as original shape # resize as original shape
i_seg_logits = resize( i_seg_logits = resize(
i_seg_logits, i_seg_logits,

View File

@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
@MODELS.register_module()
class SegTTAModel(BaseTTAModel):
def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
"""Merge predictions of enhanced data to one prediction.
Args:
data_samples_list (List[SampleList]): List of predictions
of all enhanced data.
Returns:
SampleList: Merged prediction.
"""
predictions = []
for data_samples in data_samples_list:
seg_logits = data_samples[0].seg_logits.data
logits = torch.zeros(seg_logits.shape).to(seg_logits)
for data_sample in data_samples:
seg_logit = data_sample.seg_logits.data
if self.module.out_channels > 1:
logits += seg_logit.softmax(dim=0)
else:
logits += seg_logit.sigmoid()
logits /= len(data_samples)
if self.module.out_channels == 1:
seg_pred = (logits > self.module.decode_head.threshold
).to(logits).squeeze(1)
else:
seg_pred = logits.argmax(dim=0)
data_sample = SegDataSample(
**{
'pred_sem_seg': PixelData(data=seg_pred),
'gt_sem_seg': data_samples[0].gt_sem_seg
})
predictions.append(data_sample)
return predictions

View File

@ -1,151 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# import os.path as osp import os.path as osp
# import mmcv import mmcv
# import pytest import pytest
# from mmseg.datasets.transforms import * # noqa from mmseg.datasets.transforms import * # noqa
# from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
# TODO
# def test_multi_scale_flip_aug():
# # test assertion if scales=None, scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
# # test assertion if scales=None, scale_factor=None. def test_multi_scale_flip_aug():
# with pytest.raises(AssertionError): # test exception
# tta_transform = dict( with pytest.raises(TypeError):
# type='MultiScaleFlipAug', tta_transform = dict(
# scales=None, type='TestTimeAug',
# scale_factor=None, transforms=[dict(type='Resize', keep_ratio=False)],
# transforms=[dict(type='Resize', keep_ratio=False)], )
# ) TRANSFORMS.build(tta_transform)
# TRANSFORMS.build(tta_transform)
# # test assertion if scales=(512, 512), scale_factor=1 (not float). tta_transform = dict(
# with pytest.raises(AssertionError): type='TestTimeAug',
# tta_transform = dict( transforms=[[
# type='MultiScaleFlipAug', dict(type='Resize', scale=scale, keep_ratio=False)
# scales=(512, 512), for scale in [(256, 256), (512, 512), (1024, 1024)]
# scale_factor=1, ], [dict(type='mmseg.PackSegInputs')]])
# transforms=[dict(type='Resize', keep_ratio=False)], tta_module = TRANSFORMS.build(tta_transform)
# )
# TRANSFORMS.build(tta_transform)
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
# 'scale_factor', 'scale', 'flip')
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# results = dict() results = dict()
# # (288, 512, 3) # (288, 512, 3)
# img = mmcv.imread( img = mmcv.imread(
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
# results['img'] = img results['img'] = img
# results['ori_shape'] = img.shape results['ori_shape'] = img.shape
# results['ori_height'] = img.shape[0] results['ori_height'] = img.shape[0]
# results['ori_width'] = img.shape[1] results['ori_width'] = img.shape[1]
# # Set initial values for default meta_keys # Set initial values for default meta_keys
# results['pad_shape'] = img.shape results['pad_shape'] = img.shape
# results['scale_factor'] = 1.0 results['scale_factor'] = 1.0
# tta_results = tta_module(results.copy()) tta_results = tta_module(results.copy())
# assert [data_sample.scale assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
# for data_sample in tta_results['data_sample']] == [(256, 256), (3, 512, 512),
# (512, 512), (3, 1024, 1024)]
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
# tta_transform = dict( tta_transform = dict(
# type='MultiScaleFlipAug', type='TestTimeAug',
# scales=[(256, 256), (512, 512), (1024, 1024)], transforms=[
# allow_flip=True, [
# resize_cfg=dict(type='Resize', keep_ratio=False), dict(type='Resize', scale=scale, keep_ratio=False)
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], for scale in [(256, 256), (512, 512), (1024, 1024)]
# ) ],
# tta_module = TRANSFORMS.build(tta_transform) [
# tta_results = tta_module(results.copy()) dict(type='RandomFlip', prob=0., direction='horizontal'),
# assert [data_sample.scale dict(type='RandomFlip', prob=1., direction='horizontal')
# for data_sample in tta_results['data_sample']] == [(256, 256), ], [dict(type='mmseg.PackSegInputs')]
# (256, 256), ])
# (512, 512), tta_module = TRANSFORMS.build(tta_transform)
# (512, 512), tta_results: dict = tta_module(results.copy())
# (1024, 1024), assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
# (1024, 1024)] (3, 256, 256),
# assert [data_sample.flip for data_sample in tta_results['data_sample'] (3, 512, 512),
# ] == [False, True, False, True, False, True] (3, 512, 512),
(3, 1024, 1024),
(3, 1024, 1024)]
assert [
data_sample.metainfo['flip']
for data_sample in tta_results['data_samples']
] == [False, True, False, True, False, True]
# tta_transform = dict( tta_transform = dict(
# type='MultiScaleFlipAug', type='TestTimeAug',
# scales=[(512, 512)], transforms=[[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
# allow_flip=False, [dict(type='mmseg.PackSegInputs')]])
# resize_cfg=dict(type='Resize', keep_ratio=False), tta_module = TRANSFORMS.build(tta_transform)
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], tta_results = tta_module(results.copy())
# ) assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
# assert [tta_results['data_sample'][0].flip] == [False]
# tta_transform = dict( tta_transform = dict(
# type='MultiScaleFlipAug', type='TestTimeAug',
# scales=[(512, 512)], transforms=[
# allow_flip=True, [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
# resize_cfg=dict(type='Resize', keep_ratio=False), [
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], dict(type='RandomFlip', prob=0., direction='horizontal'),
# ) dict(type='RandomFlip', prob=1., direction='horizontal')
# tta_module = TRANSFORMS.build(tta_transform) ], [dict(type='mmseg.PackSegInputs')]
# tta_results = tta_module(results.copy()) ])
# assert [data_sample.scale tta_module = TRANSFORMS.build(tta_transform)
# for data_sample in tta_results['data_sample']] == [(512, 512), tta_results = tta_module(results.copy())
# (512, 512)] assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512),
# assert [data_sample.flip (3, 512, 512)]
# for data_sample in tta_results['data_sample']] == [False, True] assert [
data_sample.metainfo['flip']
for data_sample in tta_results['data_samples']
] == [False, True]
# tta_transform = dict( tta_transform = dict(
# type='MultiScaleFlipAug', type='TestTimeAug',
# scale_factor=[0.5, 1.0, 2.0], transforms=[[
# allow_flip=False, dict(type='Resize', scale_factor=r, keep_ratio=False)
# resize_cfg=dict(type='Resize', keep_ratio=False), for r in [0.5, 1.0, 2.0]
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], ], [dict(type='mmseg.PackSegInputs')]])
# ) tta_module = TRANSFORMS.build(tta_transform)
# tta_module = TRANSFORMS.build(tta_transform) tta_results = tta_module(results.copy())
# tta_results = tta_module(results.copy()) assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
# assert [data_sample.scale (3, 288, 512),
# for data_sample in tta_results['data_sample']] == [(256, 144), (3, 576, 1024)]
# (512, 288),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
# tta_transform = dict( tta_transform = dict(
# type='MultiScaleFlipAug', type='TestTimeAug',
# scale_factor=[0.5, 1.0, 2.0], transforms=[
# allow_flip=True, [
# resize_cfg=dict(type='Resize', keep_ratio=False), dict(type='Resize', scale_factor=r, keep_ratio=True)
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], for r in [0.5, 1.0, 2.0]
# ) ],
# tta_module = TRANSFORMS.build(tta_transform) [
# tta_results = tta_module(results.copy()) dict(type='RandomFlip', prob=0., direction='horizontal'),
# assert [data_sample.scale dict(type='RandomFlip', prob=1., direction='horizontal')
# for data_sample in tta_results['data_sample']] == [(256, 144), ], [dict(type='mmseg.PackSegInputs')]
# (256, 144), ])
# (512, 288), tta_module = TRANSFORMS.build(tta_transform)
# (512, 288), tta_results = tta_module(results.copy())
# (1024, 576), assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
# (1024, 576)] (3, 144, 256),
# assert [data_sample.flip for data_sample in tta_results['data_sample'] (3, 288, 512),
# ] == [False, True, False, True, False, True] (3, 288, 512),
(3, 576, 1024),
(3, 576, 1024)]
assert [
data_sample.metainfo['flip']
for data_sample in tta_results['data_samples']
] == [False, True, False, True, False, True]

View File

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import ConfigDict
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
from .utils import * # noqa: F401,F403
register_all_modules()
def test_encoder_decoder_tta():
segmentor_cfg = ConfigDict(
type='EncoderDecoder',
backbone=dict(type='ExampleBackbone'),
decode_head=dict(type='ExampleDecodeHead'),
train_cfg=None,
test_cfg=dict(mode='whole'))
cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg)
model: BaseTTAModel = MODELS.build(cfg)
imgs = []
data_samples = []
directions = ['horizontal', 'vertical']
for i in range(12):
flip_direction = directions[0] if i % 3 == 0 else directions[1]
imgs.append(torch.randn(1, 3, 10 + i, 10 + i))
data_samples.append([
SegDataSample(
metainfo=dict(
ori_shape=(10, 10),
img_shape=(10 + i, 10 + i),
flip=(i % 2 == 0),
flip_direction=flip_direction),
gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
])
model.test_step(dict(inputs=imgs, data_samples=data_samples))
# test out_channels == 1
segmentor_cfg = ConfigDict(
type='EncoderDecoder',
backbone=dict(type='ExampleBackbone'),
decode_head=dict(
type='ExampleDecodeHead',
num_classes=2,
out_channels=1,
threshold=0.4),
train_cfg=None,
test_cfg=dict(mode='whole'))
model.module = MODELS.build(segmentor_cfg)
for data_sample in data_samples:
data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10))
model.test_step(dict(inputs=imgs, data_samples=data_samples))

View File

@ -66,9 +66,9 @@ class ExampleBackbone(nn.Module):
@MODELS.register_module() @MODELS.register_module()
class ExampleDecodeHead(BaseDecodeHead): class ExampleDecodeHead(BaseDecodeHead):
def __init__(self, num_classes=19, out_channels=None): def __init__(self, num_classes=19, out_channels=None, **kwargs):
super().__init__( super().__init__(
3, 3, num_classes=num_classes, out_channels=out_channels) 3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs)
def forward(self, inputs): def forward(self, inputs):
return self.cls_seg(inputs[0]) return self.cls_seg(inputs[0])

View File

@ -43,6 +43,8 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'], choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none', default='none',
help='job launcher') help='job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ: if 'LOCAL_RANK' not in os.environ:
@ -99,6 +101,11 @@ def main():
if args.show or args.show_dir: if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args) cfg = trigger_visualization_hook(cfg, args)
if args.tta:
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model
# build the runner from config # build the runner from config
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)