[Refactor] Support TTA (#2184)
* tta init
* use mmcv transform
* test city
* add multiscale
* fix merge
* add softmax to post process
* add ut
* add tta pipeline to other datasets
* remove softmax
* add encoder_decoder_tta ut
* add encoder_decoder_tta ut
* rename
* rename file
* rename config
* rm aug_test
* move flip to post process
* fix channel
parent 20a6c58478
commit da4125587e
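How the new TTA support is used (an illustrative sketch, not part of the diff): passing the new --tta flag to tools/test.py swaps the plain test pipeline for tta_pipeline and wraps the segmentor in SegTTAModel, which is roughly equivalent to the snippet below. The config path is hypothetical, Config and Runner come from MMEngine, and the final runner.test() call is assumed to follow as in the unmodified tool.

import torch  # noqa: F401  (only to make the sketch self-contained)
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/some_tta_ready_config.py')  # hypothetical path

# Mirror what tools/test.py does when --tta is passed (see the hunk near the
# end of this diff): the TTA pipeline replaces the plain test pipeline, and
# the original model becomes the module wrapped by SegTTAModel.
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline  # 6 scales x 2 flips
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model

runner = Runner.from_cfg(cfg)
runner.test()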
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]

train_dataloader = dict(
    batch_size=4,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -30,6 +30,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -26,6 +26,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -25,7 +25,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
dataset_train = dict(
    type=dataset_type,
    data_root=data_root,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -24,6 +24,22 @@ test_pipeline = [
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -23,6 +23,22 @@ test_pipeline = [
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
        ])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
@@ -11,3 +11,5 @@ log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
@@ -2,5 +2,8 @@
from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder
from .encoder_decoder import EncoderDecoder
from .seg_tta import SegTTAModel

__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
__all__ = [
    'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel'
]
@@ -124,11 +124,6 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
        """
        pass

    @abstractmethod
    def aug_test(self, batch_inputs, batch_img_metas):
        """Placeholder for augmentation test."""
        pass

    def postprocess_result(self,
                           seg_logits: Tensor,
                           data_samples: OptSampleList = None) -> list:
@@ -170,6 +165,15 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
                                          padding_top:H - padding_bottom,
                                          padding_left:W - padding_right]

                flip = img_meta.get('flip', None)
                if flip:
                    flip_direction = img_meta.get('flip_direction', None)
                    assert flip_direction in ['horizontal', 'vertical']
                    if flip_direction == 'horizontal':
                        i_seg_logits = i_seg_logits.flip(dims=(3, ))
                    else:
                        i_seg_logits = i_seg_logits.flip(dims=(2, ))

                # resize as original shape
                i_seg_logits = resize(
                    i_seg_logits,
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData

from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList


@MODELS.register_module()
class SegTTAModel(BaseTTAModel):

    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[SampleList]): List of predictions
                of all enhanced data.

        Returns:
            SampleList: Merged prediction.
        """
        predictions = []
        for data_samples in data_samples_list:
            seg_logits = data_samples[0].seg_logits.data
            logits = torch.zeros(seg_logits.shape).to(seg_logits)
            for data_sample in data_samples:
                seg_logit = data_sample.seg_logits.data
                if self.module.out_channels > 1:
                    logits += seg_logit.softmax(dim=0)
                else:
                    logits += seg_logit.sigmoid()
            logits /= len(data_samples)
            if self.module.out_channels == 1:
                seg_pred = (logits > self.module.decode_head.threshold
                            ).to(logits).squeeze(1)
            else:
                seg_pred = logits.argmax(dim=0)
            data_sample = SegDataSample(
                **{
                    'pred_sem_seg': PixelData(data=seg_pred),
                    'gt_sem_seg': data_samples[0].gt_sem_seg
                })
            predictions.append(data_sample)
        return predictions
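The merging rule implemented by SegTTAModel.merge_preds above averages the per-augmentation class probabilities and then takes the argmax (or thresholds the averaged sigmoid when out_channels == 1). A toy illustration of the multi-class branch using plain PyTorch, not the mmseg API; shapes and the number of views are made up:

import torch

num_classes, h, w = 19, 4, 4
# stand-ins for the seg_logits of three augmented views of one image
views = [torch.randn(num_classes, h, w) for _ in range(3)]

probs = torch.zeros(num_classes, h, w)
for seg_logit in views:
    probs += seg_logit.softmax(dim=0)   # per-pixel class probabilities
probs /= len(views)                     # average over augmentations

seg_pred = probs.argmax(dim=0)          # (h, w) merged label map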
@@ -1,151 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
# import os.path as osp
import os.path as osp

# import mmcv
# import pytest
import mmcv
import pytest

# from mmseg.datasets.transforms import * # noqa
# from mmseg.registry import TRANSFORMS
from mmseg.datasets.transforms import *  # noqa
from mmseg.registry import TRANSFORMS

# TODO
# def test_multi_scale_flip_aug():
# # test assertion if scales=None, scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)

# # test assertion if scales=None, scale_factor=None.
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=None,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
def test_multi_scale_flip_aug():
    # test exception
    with pytest.raises(TypeError):
        tta_transform = dict(
            type='TestTimeAug',
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        TRANSFORMS.build(tta_transform)

# # test assertion if scales=(512, 512), scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=(512, 512),
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
# 'scale_factor', 'scale', 'flip')
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[[
            dict(type='Resize', scale=scale, keep_ratio=False)
            for scale in [(256, 256), (512, 512), (1024, 1024)]
        ], [dict(type='mmseg.PackSegInputs')]])
    tta_module = TRANSFORMS.build(tta_transform)

# results = dict()
# # (288, 512, 3)
# img = mmcv.imread(
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
# results['img'] = img
# results['ori_shape'] = img.shape
# results['ori_height'] = img.shape[0]
# results['ori_width'] = img.shape[1]
# # Set initial values for default meta_keys
# results['pad_shape'] = img.shape
# results['scale_factor'] = 1.0
    results = dict()
    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    results['img'] = img
    results['ori_shape'] = img.shape
    results['ori_height'] = img.shape[0]
    results['ori_width'] = img.shape[1]
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 256),
# (512, 512),
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
                                                            (3, 512, 512),
                                                            (3, 1024, 1024)]

# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 256),
# (256, 256),
# (512, 512),
# (512, 512),
# (1024, 1024),
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, True, False, True, False, True]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale=scale, keep_ratio=False)
                for scale in [(256, 256), (512, 512), (1024, 1024)]
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='mmseg.PackSegInputs')]
        ])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results: dict = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256),
                                                            (3, 256, 256),
                                                            (3, 512, 512),
                                                            (3, 512, 512),
                                                            (3, 1024, 1024),
                                                            (3, 1024, 1024)]
    assert [
        data_sample.metainfo['flip']
        for data_sample in tta_results['data_samples']
    ] == [False, True, False, True, False, True]

# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(512, 512)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
# assert [tta_results['data_sample'][0].flip] == [False]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[[dict(type='Resize', scale=(512, 512), keep_ratio=False)],
                    [dict(type='mmseg.PackSegInputs')]])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [tta_results['inputs'][0].shape] == [(3, 512, 512)]

# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(512, 512)],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(512, 512),
# (512, 512)]
# assert [data_sample.flip
# for data_sample in tta_results['data_sample']] == [False, True]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[
            [dict(type='Resize', scale=(512, 512), keep_ratio=False)],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='mmseg.PackSegInputs')]
        ])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512),
                                                            (3, 512, 512)]
    assert [
        data_sample.metainfo['flip']
        for data_sample in tta_results['data_samples']
    ] == [False, True]

# tta_transform = dict(
# type='MultiScaleFlipAug',
# scale_factor=[0.5, 1.0, 2.0],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 144),
# (512, 288),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[[
            dict(type='Resize', scale_factor=r, keep_ratio=False)
            for r in [0.5, 1.0, 2.0]
        ], [dict(type='mmseg.PackSegInputs')]])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
                                                            (3, 288, 512),
                                                            (3, 576, 1024)]

# tta_transform = dict(
# type='MultiScaleFlipAug',
# scale_factor=[0.5, 1.0, 2.0],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 144),
# (256, 144),
# (512, 288),
# (512, 288),
# (1024, 576),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, True, False, True, False, True]
    tta_transform = dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in [0.5, 1.0, 2.0]
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ], [dict(type='mmseg.PackSegInputs')]
        ])
    tta_module = TRANSFORMS.build(tta_transform)
    tta_results = tta_module(results.copy())
    assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256),
                                                            (3, 144, 256),
                                                            (3, 288, 512),
                                                            (3, 288, 512),
                                                            (3, 576, 1024),
                                                            (3, 576, 1024)]
    assert [
        data_sample.metainfo['flip']
        for data_sample in tta_results['data_samples']
    ] == [False, True, False, True, False, True]
@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import ConfigDict
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData

from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
from .utils import *  # noqa: F401,F403

register_all_modules()


def test_encoder_decoder_tta():

    segmentor_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))

    cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg)

    model: BaseTTAModel = MODELS.build(cfg)

    imgs = []
    data_samples = []
    directions = ['horizontal', 'vertical']
    for i in range(12):
        flip_direction = directions[0] if i % 3 == 0 else directions[1]
        imgs.append(torch.randn(1, 3, 10 + i, 10 + i))
        data_samples.append([
            SegDataSample(
                metainfo=dict(
                    ori_shape=(10, 10),
                    img_shape=(10 + i, 10 + i),
                    flip=(i % 2 == 0),
                    flip_direction=flip_direction),
                gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
        ])

    model.test_step(dict(inputs=imgs, data_samples=data_samples))

    # test out_channels == 1
    segmentor_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(
            type='ExampleDecodeHead',
            num_classes=2,
            out_channels=1,
            threshold=0.4),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    model.module = MODELS.build(segmentor_cfg)
    for data_sample in data_samples:
        data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10))
    model.test_step(dict(inputs=imgs, data_samples=data_samples))
@@ -66,9 +66,9 @@ class ExampleBackbone(nn.Module):
@MODELS.register_module()
class ExampleDecodeHead(BaseDecodeHead):

    def __init__(self, num_classes=19, out_channels=None):
    def __init__(self, num_classes=19, out_channels=None, **kwargs):
        super().__init__(
            3, 3, num_classes=num_classes, out_channels=out_channels)
            3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])
@@ -43,6 +43,8 @@ def parse_args():
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--tta', action='store_true', help='Test time augmentation')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
@@ -99,6 +101,11 @@ def main():
    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    if args.tta:
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
        cfg.tta_model.module = cfg.model
        cfg.model = cfg.tta_model

    # build the runner from config
    runner = Runner.from_cfg(cfg)