mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactor] Support TTA (#2184)
* tta init * use mmcv transform * test city * add multiscale * fix merge * add softmax to post process * add ut * add tta pipeline to other datasets * remove softmax * add encoder_decoder_tta ut * add encoder_decoder_tta ut * rename * rename file * rename config * rm aug_test * move flip to post process * fix channel
This commit is contained in:
parent
20a6c58478
commit
da4125587e
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -24,6 +24,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
num_workers=2,
|
num_workers=2,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -24,6 +24,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -24,6 +24,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -30,6 +30,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -26,6 +26,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -25,7 +25,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
dataset_train = dict(
|
dataset_train = dict(
|
||||||
type=dataset_type,
|
type=dataset_type,
|
||||||
data_root=data_root,
|
data_root=data_root,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -24,6 +24,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations'),
|
dict(type='LoadAnnotations'),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -23,6 +23,22 @@ test_pipeline = [
|
|||||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||||
dict(type='PackSegInputs')
|
dict(type='PackSegInputs')
|
||||||
]
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
|
@ -11,3 +11,5 @@ log_processor = dict(by_epoch=False)
|
|||||||
log_level = 'INFO'
|
log_level = 'INFO'
|
||||||
load_from = None
|
load_from = None
|
||||||
resume = False
|
resume = False
|
||||||
|
|
||||||
|
tta_model = dict(type='SegTTAModel')
|
||||||
|
@ -2,5 +2,8 @@
|
|||||||
from .base import BaseSegmentor
|
from .base import BaseSegmentor
|
||||||
from .cascade_encoder_decoder import CascadeEncoderDecoder
|
from .cascade_encoder_decoder import CascadeEncoderDecoder
|
||||||
from .encoder_decoder import EncoderDecoder
|
from .encoder_decoder import EncoderDecoder
|
||||||
|
from .seg_tta import SegTTAModel
|
||||||
|
|
||||||
__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
|
__all__ = [
|
||||||
|
'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel'
|
||||||
|
]
|
||||||
|
@ -124,11 +124,6 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def aug_test(self, batch_inputs, batch_img_metas):
|
|
||||||
"""Placeholder for augmentation test."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def postprocess_result(self,
|
def postprocess_result(self,
|
||||||
seg_logits: Tensor,
|
seg_logits: Tensor,
|
||||||
data_samples: OptSampleList = None) -> list:
|
data_samples: OptSampleList = None) -> list:
|
||||||
@ -170,6 +165,15 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
|||||||
padding_top:H - padding_bottom,
|
padding_top:H - padding_bottom,
|
||||||
padding_left:W - padding_right]
|
padding_left:W - padding_right]
|
||||||
|
|
||||||
|
flip = img_meta.get('flip', None)
|
||||||
|
if flip:
|
||||||
|
flip_direction = img_meta.get('flip_direction', None)
|
||||||
|
assert flip_direction in ['horizontal', 'vertical']
|
||||||
|
if flip_direction == 'horizontal':
|
||||||
|
i_seg_logits = i_seg_logits.flip(dims=(3, ))
|
||||||
|
else:
|
||||||
|
i_seg_logits = i_seg_logits.flip(dims=(2, ))
|
||||||
|
|
||||||
# resize as original shape
|
# resize as original shape
|
||||||
i_seg_logits = resize(
|
i_seg_logits = resize(
|
||||||
i_seg_logits,
|
i_seg_logits,
|
||||||
|
48
mmseg/models/segmentors/seg_tta.py
Normal file
48
mmseg/models/segmentors/seg_tta.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mmengine.model import BaseTTAModel
|
||||||
|
from mmengine.structures import PixelData
|
||||||
|
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.structures import SegDataSample
|
||||||
|
from mmseg.utils import SampleList
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class SegTTAModel(BaseTTAModel):
|
||||||
|
|
||||||
|
def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
|
||||||
|
"""Merge predictions of enhanced data to one prediction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_samples_list (List[SampleList]): List of predictions
|
||||||
|
of all enhanced data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SampleList: Merged prediction.
|
||||||
|
"""
|
||||||
|
predictions = []
|
||||||
|
for data_samples in data_samples_list:
|
||||||
|
seg_logits = data_samples[0].seg_logits.data
|
||||||
|
logits = torch.zeros(seg_logits.shape).to(seg_logits)
|
||||||
|
for data_sample in data_samples:
|
||||||
|
seg_logit = data_sample.seg_logits.data
|
||||||
|
if self.module.out_channels > 1:
|
||||||
|
logits += seg_logit.softmax(dim=0)
|
||||||
|
else:
|
||||||
|
logits += seg_logit.sigmoid()
|
||||||
|
logits /= len(data_samples)
|
||||||
|
if self.module.out_channels == 1:
|
||||||
|
seg_pred = (logits > self.module.decode_head.threshold
|
||||||
|
).to(logits).squeeze(1)
|
||||||
|
else:
|
||||||
|
seg_pred = logits.argmax(dim=0)
|
||||||
|
data_sample = SegDataSample(
|
||||||
|
**{
|
||||||
|
'pred_sem_seg': PixelData(data=seg_pred),
|
||||||
|
'gt_sem_seg': data_samples[0].gt_sem_seg
|
||||||
|
})
|
||||||
|
predictions.append(data_sample)
|
||||||
|
return predictions
|
@ -1,151 +1,131 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
# import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
# import mmcv
|
import mmcv
|
||||||
# import pytest
|
import pytest
|
||||||
|
|
||||||
# from mmseg.datasets.transforms import * # noqa
|
from mmseg.datasets.transforms import * # noqa
|
||||||
# from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
|
||||||
# TODO
|
|
||||||
# def test_multi_scale_flip_aug():
|
|
||||||
# # test assertion if scales=None, scale_factor=1 (not float).
|
|
||||||
# with pytest.raises(AssertionError):
|
|
||||||
# tta_transform = dict(
|
|
||||||
# type='MultiScaleFlipAug',
|
|
||||||
# scales=None,
|
|
||||||
# scale_factor=1,
|
|
||||||
# transforms=[dict(type='Resize', keep_ratio=False)],
|
|
||||||
# )
|
|
||||||
# TRANSFORMS.build(tta_transform)
|
|
||||||
|
|
||||||
# # test assertion if scales=None, scale_factor=None.
|
def test_multi_scale_flip_aug():
|
||||||
# with pytest.raises(AssertionError):
|
# test exception
|
||||||
# tta_transform = dict(
|
with pytest.raises(TypeError):
|
||||||
# type='MultiScaleFlipAug',
|
tta_transform = dict(
|
||||||
# scales=None,
|
type='TestTimeAug',
|
||||||
# scale_factor=None,
|
transforms=[dict(type='Resize', keep_ratio=False)],
|
||||||
# transforms=[dict(type='Resize', keep_ratio=False)],
|
)
|
||||||
# )
|
TRANSFORMS.build(tta_transform)
|
||||||
# TRANSFORMS.build(tta_transform)
|
|
||||||
|
|
||||||
# # test assertion if scales=(512, 512), scale_factor=1 (not float).
|
tta_transform = dict(
|
||||||
# with pytest.raises(AssertionError):
|
type='TestTimeAug',
|
||||||
# tta_transform = dict(
|
transforms=[[
|
||||||
# type='MultiScaleFlipAug',
|
dict(type='Resize', scale=scale, keep_ratio=False)
|
||||||
# scales=(512, 512),
|
for scale in [(256, 256), (512, 512), (1024, 1024)]
|
||||||
# scale_factor=1,
|
], [dict(type='mmseg.PackSegInputs')]])
|
||||||
# transforms=[dict(type='Resize', keep_ratio=False)],
|
tta_module = TRANSFORMS.build(tta_transform)
|
||||||
# )
|
|
||||||
# TRANSFORMS.build(tta_transform)
|
|
||||||
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
|
|
||||||
# 'scale_factor', 'scale', 'flip')
|
|
||||||
# tta_transform = dict(
|
|
||||||
# type='MultiScaleFlipAug',
|
|
||||||
# scales=[(256, 256), (512, 512), (1024, 1024)],
|
|
||||||
# allow_flip=False,
|
|
||||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
|
||||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
|
||||||
# )
|
|
||||||
# tta_module = TRANSFORMS.build(tta_transform)
|
|
||||||
|
|
||||||
# results = dict()
|
results = dict()
|
||||||
# # (288, 512, 3)
|
# (288, 512, 3)
|
||||||
# img = mmcv.imread(
|
img = mmcv.imread(
|
||||||
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
# results['img'] = img
|
results['img'] = img
|
||||||
# results['ori_shape'] = img.shape
|
results['ori_shape'] = img.shape
|
||||||
# results['ori_height'] = img.shape[0]
|
results['ori_height'] = img.shape[0]
|
||||||
# results['ori_width'] = img.shape[1]
|
results['ori_width'] = img.shape[1]
|
||||||
# # Set initial values for default meta_keys
|
# Set initial values for default meta_keys
|
||||||
# results['pad_shape'] = img.shape
|
results['pad_shape'] = img.shape
|
||||||
# results['scale_factor'] = 1.0
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
# tta_results = tta_module(results.copy())
|
tta_results = tta_module(results.copy())
|
||||||
# assert [data_sample.scale
|
assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
|
||||||
# for data_sample in tta_results['data_sample']] == [(256, 256),
|
(3, 512, 512),
|
||||||
# (512, 512),
|
(3, 1024, 1024)]
|
||||||
# (1024, 1024)]
|
|
||||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
|
||||||
# ] == [False, False, False]
|
|
||||||
|
|
||||||
# tta_transform = dict(
|
tta_transform = dict(
|
||||||
# type='MultiScaleFlipAug',
|
type='TestTimeAug',
|
||||||
# scales=[(256, 256), (512, 512), (1024, 1024)],
|
transforms=[
|
||||||
# allow_flip=True,
|
[
|
||||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
dict(type='Resize', scale=scale, keep_ratio=False)
|
||||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
for scale in [(256, 256), (512, 512), (1024, 1024)]
|
||||||
# )
|
],
|
||||||
# tta_module = TRANSFORMS.build(tta_transform)
|
[
|
||||||
# tta_results = tta_module(results.copy())
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
# assert [data_sample.scale
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
# for data_sample in tta_results['data_sample']] == [(256, 256),
|
], [dict(type='mmseg.PackSegInputs')]
|
||||||
# (256, 256),
|
])
|
||||||
# (512, 512),
|
tta_module = TRANSFORMS.build(tta_transform)
|
||||||
# (512, 512),
|
tta_results: dict = tta_module(results.copy())
|
||||||
# (1024, 1024),
|
assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
|
||||||
# (1024, 1024)]
|
(3, 256, 256),
|
||||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
(3, 512, 512),
|
||||||
# ] == [False, True, False, True, False, True]
|
(3, 512, 512),
|
||||||
|
(3, 1024, 1024),
|
||||||
|
(3, 1024, 1024)]
|
||||||
|
assert [
|
||||||
|
data_sample.metainfo['flip']
|
||||||
|
for data_sample in tta_results['data_samples']
|
||||||
|
] == [False, True, False, True, False, True]
|
||||||
|
|
||||||
# tta_transform = dict(
|
tta_transform = dict(
|
||||||
# type='MultiScaleFlipAug',
|
type='TestTimeAug',
|
||||||
# scales=[(512, 512)],
|
transforms=[[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
|
||||||
# allow_flip=False,
|
[dict(type='mmseg.PackSegInputs')]])
|
||||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
tta_module = TRANSFORMS.build(tta_transform)
|
||||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
tta_results = tta_module(results.copy())
|
||||||
# )
|
assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]
|
||||||
# tta_module = TRANSFORMS.build(tta_transform)
|
|
||||||
# tta_results = tta_module(results.copy())
|
|
||||||
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
|
|
||||||
# assert [tta_results['data_sample'][0].flip] == [False]
|
|
||||||
|
|
||||||
# tta_transform = dict(
|
tta_transform = dict(
|
||||||
# type='MultiScaleFlipAug',
|
type='TestTimeAug',
|
||||||
# scales=[(512, 512)],
|
transforms=[
|
||||||
# allow_flip=True,
|
[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
|
||||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
[
|
||||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
# )
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
# tta_module = TRANSFORMS.build(tta_transform)
|
], [dict(type='mmseg.PackSegInputs')]
|
||||||
# tta_results = tta_module(results.copy())
|
])
|
||||||
# assert [data_sample.scale
|
tta_module = TRANSFORMS.build(tta_transform)
|
||||||
# for data_sample in tta_results['data_sample']] == [(512, 512),
|
tta_results = tta_module(results.copy())
|
||||||
# (512, 512)]
|
assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512),
|
||||||
# assert [data_sample.flip
|
(3, 512, 512)]
|
||||||
# for data_sample in tta_results['data_sample']] == [False, True]
|
assert [
|
||||||
|
data_sample.metainfo['flip']
|
||||||
|
for data_sample in tta_results['data_samples']
|
||||||
|
] == [False, True]
|
||||||
|
|
||||||
# tta_transform = dict(
|
tta_transform = dict(
|
||||||
# type='MultiScaleFlipAug',
|
type='TestTimeAug',
|
||||||
# scale_factor=[0.5, 1.0, 2.0],
|
transforms=[[
|
||||||
# allow_flip=False,
|
dict(type='Resize', scale_factor=r, keep_ratio=False)
|
||||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
for r in [0.5, 1.0, 2.0]
|
||||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
], [dict(type='mmseg.PackSegInputs')]])
|
||||||
# )
|
tta_module = TRANSFORMS.build(tta_transform)
|
||||||
# tta_module = TRANSFORMS.build(tta_transform)
|
tta_results = tta_module(results.copy())
|
||||||
# tta_results = tta_module(results.copy())
|
assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
|
||||||
# assert [data_sample.scale
|
(3, 288, 512),
|
||||||
# for data_sample in tta_results['data_sample']] == [(256, 144),
|
(3, 576, 1024)]
|
||||||
# (512, 288),
|
|
||||||
# (1024, 576)]
|
|
||||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
|
||||||
# ] == [False, False, False]
|
|
||||||
|
|
||||||
# tta_transform = dict(
|
tta_transform = dict(
|
||||||
# type='MultiScaleFlipAug',
|
type='TestTimeAug',
|
||||||
# scale_factor=[0.5, 1.0, 2.0],
|
transforms=[
|
||||||
# allow_flip=True,
|
[
|
||||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
for r in [0.5, 1.0, 2.0]
|
||||||
# )
|
],
|
||||||
# tta_module = TRANSFORMS.build(tta_transform)
|
[
|
||||||
# tta_results = tta_module(results.copy())
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
# assert [data_sample.scale
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
# for data_sample in tta_results['data_sample']] == [(256, 144),
|
], [dict(type='mmseg.PackSegInputs')]
|
||||||
# (256, 144),
|
])
|
||||||
# (512, 288),
|
tta_module = TRANSFORMS.build(tta_transform)
|
||||||
# (512, 288),
|
tta_results = tta_module(results.copy())
|
||||||
# (1024, 576),
|
assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
|
||||||
# (1024, 576)]
|
(3, 144, 256),
|
||||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
(3, 288, 512),
|
||||||
# ] == [False, True, False, True, False, True]
|
(3, 288, 512),
|
||||||
|
(3, 576, 1024),
|
||||||
|
(3, 576, 1024)]
|
||||||
|
assert [
|
||||||
|
data_sample.metainfo['flip']
|
||||||
|
for data_sample in tta_results['data_samples']
|
||||||
|
] == [False, True, False, True, False, True]
|
||||||
|
60
tests/test_models/test_segmentors/test_seg_tta_model.py
Normal file
60
tests/test_models/test_segmentors/test_seg_tta_model.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
|
from mmengine import ConfigDict
|
||||||
|
from mmengine.model import BaseTTAModel
|
||||||
|
from mmengine.structures import PixelData
|
||||||
|
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.structures import SegDataSample
|
||||||
|
from mmseg.utils import register_all_modules
|
||||||
|
from .utils import * # noqa: F401,F403
|
||||||
|
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_decoder_tta():
|
||||||
|
|
||||||
|
segmentor_cfg = ConfigDict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
backbone=dict(type='ExampleBackbone'),
|
||||||
|
decode_head=dict(type='ExampleDecodeHead'),
|
||||||
|
train_cfg=None,
|
||||||
|
test_cfg=dict(mode='whole'))
|
||||||
|
|
||||||
|
cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg)
|
||||||
|
|
||||||
|
model: BaseTTAModel = MODELS.build(cfg)
|
||||||
|
|
||||||
|
imgs = []
|
||||||
|
data_samples = []
|
||||||
|
directions = ['horizontal', 'vertical']
|
||||||
|
for i in range(12):
|
||||||
|
flip_direction = directions[0] if i % 3 == 0 else directions[1]
|
||||||
|
imgs.append(torch.randn(1, 3, 10 + i, 10 + i))
|
||||||
|
data_samples.append([
|
||||||
|
SegDataSample(
|
||||||
|
metainfo=dict(
|
||||||
|
ori_shape=(10, 10),
|
||||||
|
img_shape=(10 + i, 10 + i),
|
||||||
|
flip=(i % 2 == 0),
|
||||||
|
flip_direction=flip_direction),
|
||||||
|
gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
|
||||||
|
])
|
||||||
|
|
||||||
|
model.test_step(dict(inputs=imgs, data_samples=data_samples))
|
||||||
|
|
||||||
|
# test out_channels == 1
|
||||||
|
segmentor_cfg = ConfigDict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
backbone=dict(type='ExampleBackbone'),
|
||||||
|
decode_head=dict(
|
||||||
|
type='ExampleDecodeHead',
|
||||||
|
num_classes=2,
|
||||||
|
out_channels=1,
|
||||||
|
threshold=0.4),
|
||||||
|
train_cfg=None,
|
||||||
|
test_cfg=dict(mode='whole'))
|
||||||
|
model.module = MODELS.build(segmentor_cfg)
|
||||||
|
for data_sample in data_samples:
|
||||||
|
data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10))
|
||||||
|
model.test_step(dict(inputs=imgs, data_samples=data_samples))
|
@ -66,9 +66,9 @@ class ExampleBackbone(nn.Module):
|
|||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class ExampleDecodeHead(BaseDecodeHead):
|
class ExampleDecodeHead(BaseDecodeHead):
|
||||||
|
|
||||||
def __init__(self, num_classes=19, out_channels=None):
|
def __init__(self, num_classes=19, out_channels=None, **kwargs):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
3, 3, num_classes=num_classes, out_channels=out_channels)
|
3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs)
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
return self.cls_seg(inputs[0])
|
return self.cls_seg(inputs[0])
|
||||||
|
@ -43,6 +43,8 @@ def parse_args():
|
|||||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||||
default='none',
|
default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
|
parser.add_argument(
|
||||||
|
'--tta', action='store_true', help='Test time augmentation')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if 'LOCAL_RANK' not in os.environ:
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
@ -99,6 +101,11 @@ def main():
|
|||||||
if args.show or args.show_dir:
|
if args.show or args.show_dir:
|
||||||
cfg = trigger_visualization_hook(cfg, args)
|
cfg = trigger_visualization_hook(cfg, args)
|
||||||
|
|
||||||
|
if args.tta:
|
||||||
|
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
|
||||||
|
cfg.tta_model.module = cfg.model
|
||||||
|
cfg.model = cfg.tta_model
|
||||||
|
|
||||||
# build the runner from config
|
# build the runner from config
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user