Merge branch 'limengzhang/refactor_PackSegInputs' into 'refactor_dev'
[Refactor] Add PackSegInputs and some Transforms See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!10pull/1801/head
commit
02f92276ae
|
@ -1,19 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.transforms import (LoadImageFromFile, MultiScaleFlipAug, Normalize,
|
||||
Pad, RandomChoiceResize, RandomFlip, RandomResize,
|
||||
Resize)
|
||||
|
||||
from .compose import Compose
|
||||
from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
from .loading import LoadAnnotations, LoadImageFromFile
|
||||
from .test_time_aug import MultiScaleFlipAug
|
||||
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomFlip, RandomMosaic, RandomRotate, Rerange,
|
||||
Resize, RGB2Gray, SegRescale)
|
||||
from .formatting import (ImageToTensor, PackSegInputs, ToDataContainer,
|
||||
Transpose)
|
||||
from .loading import LoadAnnotations
|
||||
from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
RGB2Gray, SegRescale)
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
||||
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
|
||||
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||
'Compose', 'ImageToTensor', 'ToDataContainer', 'Transpose',
|
||||
'LoadAnnotations', 'LoadImageFromFile', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
||||
'RandomMosaic'
|
||||
'RandomMosaic', 'PackSegInputs', 'Resize', 'RandomResize',
|
||||
'RandomChoiceResize', 'MultiScaleFlipAug'
|
||||
]
|
||||
|
|
|
@ -1,67 +1,95 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections.abc import Sequence
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
from mmcv.transforms import to_tensor
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmengine.data import PixelData
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
def to_tensor(data):
    """Turn a supported Python or NumPy object into a :obj:`torch.Tensor`.

    Accepted inputs are :class:`torch.Tensor`, :class:`numpy.ndarray`,
    non-string :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): The
            object to convert.

    Returns:
        torch.Tensor: ``data`` itself when it already is a tensor, otherwise
        a tensor built from it. Scalars become one-element tensors.

    Raises:
        TypeError: If ``data`` is none of the supported types.
    """
    # An existing tensor is handed back as-is.
    if isinstance(data, torch.Tensor):
        return data

    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)

    # Strings are sequences as well, but they are not converted.
    is_plain_sequence = isinstance(data, Sequence) and not mmcv.is_str(data)
    if is_plain_sequence:
        return torch.tensor(data)

    if isinstance(data, int):
        return torch.LongTensor([data])

    if isinstance(data, float):
        return torch.FloatTensor([data])

    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToTensor(object):
|
||||
"""Convert some results to :obj:`torch.Tensor` by given keys.
|
||||
class PackSegInputs(BaseTransform):
|
||||
"""Pack the inputs data for the semantic segmentation.
|
||||
|
||||
The ``img_meta`` item is always populated. The contents of the
|
||||
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
|
||||
|
||||
- ``filename``: filename of the image
|
||||
|
||||
- ``ori_filename``: original filename of the image file
|
||||
|
||||
- ``ori_shape``: original shape of the image as a tuple (h, w, c)
|
||||
|
||||
- ``img_shape``: shape of the image input to the network as a tuple \
|
||||
(h, w, c). Note that images may be zero padded on the \
|
||||
bottom/right if the batch tensor is larger than this shape.
|
||||
|
||||
- ``pad_shape``: shape of padded images
|
||||
|
||||
- ``scale_factor``: a float indicating the preprocessing scale
|
||||
|
||||
- ``flip``: a boolean indicating if image flip transform was used
|
||||
|
||||
- ``flip_direction``: the flipping direction
|
||||
|
||||
- ``img_norm_cfg``: config of image pixel normalization
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): Keys that need to be converted to Tensor.
|
||||
meta_keys (Sequence[str], optional): Meta keys to be packed from
|
||||
``SegDataSample`` and collected in ``data[img_metas]``.
|
||||
Default: ``('filename', 'ori_filename', 'ori_shape',
|
||||
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction', 'img_norm_cfg')``
|
||||
"""
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
def __init__(self,
|
||||
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
||||
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction', 'img_norm_cfg')):
|
||||
self.meta_keys = meta_keys
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to convert data in results to :obj:`torch.Tensor`.
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Method to pack the input data.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict contains the data to convert.
|
||||
results (dict): Result dict from the data pipeline.
|
||||
|
||||
Returns:
|
||||
dict: The result dict contains the data converted
|
||||
to :obj:`torch.Tensor`.
|
||||
dict:
|
||||
|
||||
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
|
||||
- 'data_sample' (obj:`SegDataSample`): The annotation info of the
|
||||
sample.
|
||||
"""
|
||||
packed_results = dict()
|
||||
if 'img' in results:
|
||||
img = results['img']
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
|
||||
for key in self.keys:
|
||||
results[key] = to_tensor(results[key])
|
||||
return results
|
||||
data_sample = SegDataSample()
|
||||
if 'gt_seg_map' in results:
|
||||
gt_sem_seg_data = dict(
|
||||
data=results['gt_seg_map'][None, ...].astype(np.int64))
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
img_meta = {}
|
||||
for key in self.meta_keys:
|
||||
if key in results:
|
||||
img_meta[key] = results[key]
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_sample'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(meta_keys={self.meta_keys})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
|
@ -173,117 +201,3 @@ class ToDataContainer(object):
|
|||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(fields={self.fields})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class DefaultFormatBundle(object):
    """Default formatting bundle.

    Formats the two common pipeline fields, "img" and "gt_semantic_seg",
    when they are present in the result dict:

    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
    - gt_semantic_seg: (1) unsqueeze dim-0, (2) to tensor,
      (3) to DataContainer (stack=True)
    """

    def __call__(self, results):
        """Format the common fields of ``results`` in place.

        Args:
            results (dict): Result dict that contains the data to convert.

        Returns:
            dict: The same dict, with "img" and "gt_semantic_seg" replaced
            by their formatted versions where those keys exist.
        """
        if 'img' in results:
            image = results['img']
            # Images with fewer than three axes get a trailing axis so the
            # transpose below is always valid.
            if len(image.shape) < 3:
                image = np.expand_dims(image, -1)
            # Move the last axis to the front and make the memory contiguous.
            image = np.ascontiguousarray(image.transpose(2, 0, 1))
            results['img'] = DC(to_tensor(image), stack=True)

        if 'gt_semantic_seg' in results:
            # Prepend a leading axis and convert to long.
            seg_map = results['gt_semantic_seg'][None, ...].astype(np.int64)
            results['gt_semantic_seg'] = DC(to_tensor(seg_map), stack=True)

        return results

    def __repr__(self):
        return type(self).__name__
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Collect(object):
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically
    ``keys`` is set to some subset of "img", "gt_semantic_seg".

    The "img_metas" item is always populated. Its contents depend on
    ``meta_keys``; the documented meta fields are:

    - "img_shape": shape of the image input to the network as a tuple
      (h, w, c). Note that images may be zero padded on the bottom/right
      if the batch tensor is larger than this shape.

    - "scale_factor": a float indicating the preprocessing scale

    - "flip": a boolean indicating if image flip transform was used

    - "filename": path to the image file

    - "ori_shape": original shape of the image as a tuple (h, w, c)

    - "pad_shape": image shape after padding

    - "img_norm_cfg": a dict of normalization information:

        - mean - per channel mean subtraction
        - std - per channel std divisor
        - to_rgb - bool indicating if bgr was converted to rgb

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Meta keys to be converted to
            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
            Default: (``filename``, ``ori_filename``, ``ori_shape``,
            ``img_shape``, ``pad_shape``, ``scale_factor``, ``flip``,
            ``flip_direction``, ``img_norm_cfg``)
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Gather the requested keys and meta information from ``results``.

        The values of ``meta_keys`` are bundled into one dict that is wrapped
        in a :obj:`mmcv.DataContainer`.

        Args:
            results (dict): Result dict that contains the data to collect.

        Returns:
            dict: A new dict with the following keys

                - ``img_metas``
                - keys in ``self.keys``

        Raises:
            KeyError: If a requested key or meta key is absent from
                ``results``.
        """
        # Meta information is looked up first, so a missing meta key fails
        # before any payload key is read.
        meta_info = {name: results[name] for name in self.meta_keys}
        collected = {'img_metas': DC(meta_info, cpu_only=True)}
        collected.update((name, results[name]) for name in self.keys)
        return collected

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(keys={self.keys}, meta_keys={self.meta_keys})')
|
||||
|
|
|
@ -1,138 +1,89 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromFile(object):
|
||||
"""Load an image from file.
|
||||
class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
"""Load annotations for semantic segmentation provided by dataset.
|
||||
|
||||
Required keys are "img_prefix" and "img_info" (a dict that must contain the
|
||||
key "filename"). Added or updated keys are "filename", "img", "img_shape",
|
||||
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
|
||||
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
|
||||
The annotation format is as the following:
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
|
||||
Defaults to 'color'.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmcv.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
|
||||
'cv2'
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
def __init__(self,
|
||||
to_float32=False,
|
||||
color_type='color',
|
||||
file_client_args=dict(backend='disk'),
|
||||
imdecode_backend='cv2'):
|
||||
self.to_float32 = to_float32
|
||||
self.color_type = color_type
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = None
|
||||
self.imdecode_backend = imdecode_backend
|
||||
{
|
||||
# Filename of semantic segmentation ground truth file.
|
||||
'seg_map_path': 'a/b/c'
|
||||
}
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call functions to load image and get image meta information.
|
||||
After this module, the annotation has been changed to the format below:
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:`mmseg.CustomDataset`.
|
||||
.. code-block:: python
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
{
|
||||
# in str
|
||||
'seg_fields': List
|
||||
# In uint8 type.
|
||||
'gt_seg_map': np.ndarray (H, W)
|
||||
}
|
||||
|
||||
if self.file_client is None:
|
||||
self.file_client = mmcv.FileClient(**self.file_client_args)
|
||||
Required Keys:
|
||||
|
||||
if results.get('img_prefix') is not None:
|
||||
filename = osp.join(results['img_prefix'],
|
||||
results['img_info']['filename'])
|
||||
else:
|
||||
filename = results['img_info']['filename']
|
||||
img_bytes = self.file_client.get(filename)
|
||||
img = mmcv.imfrombytes(
|
||||
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
- seg_map_path (str): Path of semantic segmentation ground truth file.
|
||||
|
||||
results['filename'] = filename
|
||||
results['ori_filename'] = results['img_info']['filename']
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
|
||||
results['img_norm_cfg'] = dict(
|
||||
mean=np.zeros(num_channels, dtype=np.float32),
|
||||
std=np.ones(num_channels, dtype=np.float32),
|
||||
to_rgb=False)
|
||||
return results
|
||||
Added Keys:
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(to_float32={self.to_float32},'
|
||||
repr_str += f"color_type='{self.color_type}',"
|
||||
repr_str += f"imdecode_backend='{self.imdecode_backend}')"
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadAnnotations(object):
|
||||
"""Load annotations for semantic segmentation.
|
||||
- seg_fields (List)
|
||||
- gt_seg_map (np.uint8)
|
||||
|
||||
Args:
|
||||
reduce_zero_label (bool): Whether reduce all label value by 1.
|
||||
Usually used for datasets where 0 is background label.
|
||||
Default: False.
|
||||
Defaults to False.
|
||||
imdecode_backend (str): The image decoding backend type. The backend
|
||||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :fun:``mmcv.imfrombytes`` for details.
|
||||
Defaults to 'pillow'.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmcv.fileio.FileClient` for details.
|
||||
See :class:``mmcv.fileio.FileClient`` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
|
||||
'pillow'
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduce_zero_label=False,
|
||||
file_client_args=dict(backend='disk'),
|
||||
imdecode_backend='pillow'):
|
||||
def __init__(
|
||||
self,
|
||||
reduce_zero_label=False,
|
||||
file_client_args=dict(backend='disk'),
|
||||
imdecode_backend='pillow',
|
||||
) -> None:
|
||||
super().__init__(
|
||||
with_bbox=False,
|
||||
with_label=False,
|
||||
with_seg=True,
|
||||
with_keypoints=False,
|
||||
imdecode_backend=imdecode_backend,
|
||||
file_client_args=file_client_args)
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = None
|
||||
self.imdecode_backend = imdecode_backend
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to load multiple types annotations.
|
||||
def _load_seg_map(self, results: dict) -> None:
|
||||
"""Private function to load semantic segmentation annotations.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:`mmseg.CustomDataset`.
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded semantic segmentation annotations.
|
||||
"""
|
||||
|
||||
if self.file_client is None:
|
||||
self.file_client = mmcv.FileClient(**self.file_client_args)
|
||||
|
||||
if results.get('seg_prefix', None) is not None:
|
||||
filename = osp.join(results['seg_prefix'],
|
||||
results['ann_info']['seg_map'])
|
||||
else:
|
||||
filename = results['ann_info']['seg_map']
|
||||
img_bytes = self.file_client.get(filename)
|
||||
img_bytes = self.file_client.get(results['seg_map_path'])
|
||||
gt_semantic_seg = mmcv.imfrombytes(
|
||||
img_bytes, flag='unchanged',
|
||||
backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
|
@ -147,12 +98,12 @@ class LoadAnnotations(object):
|
|||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
gt_semantic_seg = gt_semantic_seg - 1
|
||||
gt_semantic_seg[gt_semantic_seg == 254] = 255
|
||||
results['gt_semantic_seg'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_semantic_seg')
|
||||
return results
|
||||
results['gt_seg_map'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_seg_map')
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
    """Return a readable summary of this transform's configuration.

    Returns:
        str: The class name followed by ``reduce_zero_label``,
        ``imdecode_backend`` and ``file_client_args`` in one balanced,
        comma-separated argument list.
    """
    repr_str = self.__class__.__name__
    repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
    # The parenthesis is closed only after the final field; closing it here
    # would leave ``file_client_args`` dangling outside the argument list.
    repr_str += f"imdecode_backend='{self.imdecode_backend}', "
    repr_str += f'file_client_args={self.file_client_args})'
    return repr_str
|
||||
|
|
|
@ -1,134 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from .compose import Compose
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=(2048, 1024),
        img_ratios=[0.5, 1.0],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
            flip=[False, True, False, True]
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (None | tuple | list[tuple]): Images scales for resizing.
        img_ratios (float | list[float]): Image ratios for resizing
        flip (bool): Whether apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal" and "vertical". If flip_direction is
            list, multiple flip augmentations will be applied.
            It has no effect when flip == False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 img_ratios=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        if img_ratios is not None:
            # Accept a single ratio as shorthand for a one-element list.
            img_ratios = img_ratios if isinstance(img_ratios,
                                                  list) else [img_ratios]
            assert mmcv.is_list_of(img_ratios, float)
        if img_scale is None:
            # mode 1: given img_scale=None and a range of image ratio
            self.img_scale = None
            # Ratios are mandatory here; this also rejects img_ratios=None.
            assert mmcv.is_list_of(img_ratios, float)
        elif isinstance(img_scale, tuple) and mmcv.is_list_of(
                img_ratios, float):
            assert len(img_scale) == 2
            # mode 2: given a scale and a range of image ratio
            self.img_scale = [(int(img_scale[0] * ratio),
                               int(img_scale[1] * ratio))
                              for ratio in img_ratios]
        else:
            # mode 3: given multiple scales
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
        self.flip = flip
        self.img_ratios = img_ratios
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any(t['type'] == 'RandomFlip' for t in transforms)):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
            into a list.
        """
        aug_data = []
        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
            # Scales are derived from the size of the incoming image.
            h, w = results['img'].shape[:2]
            img_scale = [(int(w * ratio), int(h * ratio))
                         for ratio in self.img_ratios]
        else:
            img_scale = self.img_scale
        flip_aug = [False, True] if self.flip else [False]
        # One pipeline run per (scale, flip flag, direction) combination.
        for scale in img_scale:
            for flip in flip_aug:
                for direction in self.flip_direction:
                    # Shallow copy so each run gets its own control keys.
                    _results = results.copy()
                    _results['scale'] = scale
                    _results['flip'] = flip
                    _results['flip_direction'] = direction
                    data = self.transforms(_results)
                    aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        """Return the class name with all fields in one balanced list."""
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        # Close the parenthesis only after the last field.
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str
|
|
@ -6,7 +6,7 @@ import mmcv
|
|||
import numpy as np
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
||||
from mmcv.utils import is_tuple_of
|
||||
from numpy import random
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
@ -69,429 +69,6 @@ class ResizeToMultiple(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Resize(object):
    """Resize images & seg.

    This transform resizes the input image to some scale. If the input dict
    contains the key "scale", then the scale in the input dict is used,
    otherwise the specified scale in the init method is used.

    ``img_scale`` can be None, a tuple (single-scale) or a list of tuple
    (multi-scale). There are 4 multiscale modes:

    - ``ratio_range is not None``:
        1. When img_scale is None, img_scale is the shape of image in results
           (img_scale = results['img'].shape[:2]) and the image is resized
           based on the original size. (mode 1)
        2. When img_scale is a tuple (single-scale), randomly sample a ratio
           from the ratio range and multiply it with the image scale.
           (mode 2)

    - ``ratio_range is None and multiscale_mode == "range"``: randomly sample
      a scale from a range. (mode 3)

    - ``ratio_range is None and multiscale_mode == "value"``: randomly sample
      a scale from multiple scales. (mode 4)

    Args:
        img_scale (tuple or list[tuple]): Images scales for resizing.
            Default: None.
        multiscale_mode (str): Either "range" or "value".
            Default: 'range'
        ratio_range (tuple[float]): (min_ratio, max_ratio).
            Default: None
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Default: True
        min_size (int, optional): The minimum size for input and the shape
            of the image and seg map will not be less than ``min_size``.
            As the shape of model input is fixed like 'SETR' and 'BEiT'.
            Following the setting in these models, resized images must be
            bigger than the crop size in ``slide_inference``. Default: None
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 min_size=None):
        # Normalise a single scale to a one-element list before validating.
        if img_scale is not None:
            if not isinstance(img_scale, list):
                img_scale = [img_scale]
            assert mmcv.is_list_of(img_scale, tuple)
        self.img_scale = img_scale

        if ratio_range is None:
            # mode 3 and 4: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']
        else:
            # mode 1: given img_scale=None and a range of image ratio
            # mode 2: given a scale and a range of image ratio
            assert self.img_scale is None or len(self.img_scale) == 1

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        self.min_size = min_size

    @staticmethod
    def random_select(img_scales):
        """Pick one scale at random from the given candidates.

        Args:
            img_scales (list[tuple]): Images scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
            where ``img_scale`` is the selected image scale and
            ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        chosen_idx = np.random.randint(len(img_scales))
        return img_scales[chosen_idx], chosen_idx

    @staticmethod
    def random_sample(img_scales):
        """Sample a scale between two bounds (``multiscale_mode=='range'``).

        Args:
            img_scales (list[tuple]): Images scale range for sampling.
                There must be two tuples in img_scales, which specify the
                lower and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where
            ``img_scale`` is sampled scale and None is just a placeholder
            to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        long_edges = [max(bound) for bound in img_scales]
        short_edges = [min(bound) for bound in img_scales]
        # The long edge is drawn before the short edge; the upper bounds are
        # inclusive, hence the ``+ 1``.
        long_edge = np.random.randint(min(long_edges), max(long_edges) + 1)
        short_edge = np.random.randint(min(short_edges), max(short_edges) + 1)
        return (long_edge, short_edge), None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Sample a scale by drawing a ratio from ``ratio_range``.

        A ratio is drawn uniformly from the range given by ``ratio_range``
        and multiplied with ``img_scale`` to produce the sampled scale.

        Args:
            img_scale (tuple): Images scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where
            ``scale`` is sampled ratio multiplied with ``img_scale`` and
            None is just a placeholder to be consistent with
            :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        sampled = (int(img_scale[0] * ratio), int(img_scale[1] * ratio))
        return sampled, None

    def _random_scale(self, results):
        """Choose a scale according to ``ratio_range`` and
        ``multiscale_mode`` and store it in ``results``.

        If ``ratio_range`` is specified, a ratio will be sampled and be
        multiplied with ``img_scale`` (or with the image's own size when
        ``img_scale`` is None).
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, single scale will be used.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            None: The keys 'scale' and 'scale_idx' are written into
            ``results`` for use by subsequent pipelines.
        """
        if self.ratio_range is not None:
            if self.img_scale is None:
                # No base scale configured: use the image size as (w, h).
                h, w = results['img'].shape[:2]
                base_scale = (w, h)
            else:
                base_scale = self.img_scale[0]
            scale, scale_idx = self.random_sample_ratio(
                base_scale, self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        results['scale'] = scale
        results['scale_idx'] = scale_idx

    def _resize_img(self, results):
        """Resize images with ``results['scale']``."""
        source = results['img']
        if self.keep_ratio:
            if self.min_size is not None:
                # TODO: Now 'min_size' is an 'int' which means the minimum
                # shape of images is (min_size, min_size, 3). 'min_size'
                # with tuple type will be supported, i.e. the width and
                # height are not equal.
                shortest = min(results['scale'])
                new_short = (
                    self.min_size if shortest < self.min_size else shortest)

                h, w = source.shape[:2]
                if h > w:
                    new_h, new_w = new_short * h / w, new_short
                else:
                    new_h, new_w = new_short, new_short * w / h
                results['scale'] = (new_h, new_w)

            # The scalar factor returned here is discarded; per-axis factors
            # are recomputed from the actual output shape below.
            img, _ = mmcv.imrescale(
                source, results['scale'], return_scale=True)
            # the w_scale and h_scale has minor difference
            # a real fix should be done in the mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]
            h, w = source.shape[:2]
            w_scale = new_w / w
            h_scale = new_h / h
        else:
            img, w_scale, h_scale = mmcv.imresize(
                source, results['scale'], return_scale=True)

        results['img'] = img
        results['img_shape'] = img.shape
        results['pad_shape'] = img.shape  # in case that there is no padding
        results['scale_factor'] = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        results['keep_ratio'] = self.keep_ratio

    def _resize_seg(self, results):
        """Resize semantic segmentation map with ``results['scale']``."""
        for key in results.get('seg_fields', []):
            # Same resize routine as for the image, with nearest-neighbour
            # interpolation.
            resize_fn = mmcv.imrescale if self.keep_ratio else mmcv.imresize
            results[key] = resize_fn(
                results[key], results['scale'], interpolation='nearest')

    def __call__(self, results):
        """Resize the image and every map listed in ``seg_fields``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
            'keep_ratio' keys are added into result dict.
        """
        # A scale already present in the dict takes precedence over sampling.
        if 'scale' not in results:
            self._random_scale(results)
        self._resize_img(results)
        self._resize_seg(results)
        return results

    def __repr__(self):
        fields = (f'(img_scale={self.img_scale}, '
                  f'multiscale_mode={self.multiscale_mode}, '
                  f'ratio_range={self.ratio_range}, '
                  f'keep_ratio={self.keep_ratio})')
        return self.__class__.__name__ + fields
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class RandomFlip(object):
    """Flip the image & seg.

    If the input dict contains the key "flip", then the flag will be used,
    otherwise it will be randomly decided by a ratio specified in the init
    method.

    Args:
        prob (float, optional): The flipping probability, in [0, 1].
            ``None`` means this transform never decides to flip on its own;
            flipping then only happens when the incoming dict already
            carries ``flip=True``. Default: None.
        direction(str, optional): The flipping direction. Options are
            'horizontal' and 'vertical'. Default: 'horizontal'.
    """

    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
    def __init__(self, prob=None, direction='horizontal'):
        self.prob = prob
        self.direction = direction
        if prob is not None:
            assert prob >= 0 and prob <= 1
        assert direction in ['horizontal', 'vertical']

    def __call__(self, results):
        """Call function to flip the image and semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Flipped results, 'flip', 'flip_direction' keys are added into
                result dict.
        """
        if 'flip' not in results:
            # Comparing the random draw against ``prob=None`` would raise a
            # TypeError, so an unset probability is treated as "do not flip".
            results['flip'] = (self.prob is not None
                               and bool(np.random.rand() < self.prob))
        if 'flip_direction' not in results:
            results['flip_direction'] = self.direction
        if results['flip']:
            # flip image
            results['img'] = mmcv.imflip(
                results['img'], direction=results['flip_direction'])

            # flip segs
            for key in results.get('seg_fields', []):
                # use copy() to make numpy stride positive
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction']).copy()
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(prob={self.prob})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class Pad(object):
    """Pad the image & mask.

    Two padding modes exist: (1) pad to a fixed size and (2) pad to the
    minimum size that is divisible by some number. Exactly one of ``size``
    and ``size_divisor`` must be given.
    Added keys are "pad_shape", "pad_fixed_size" and "pad_size_divisor".

    Args:
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
    """

    def __init__(self,
                 size=None,
                 size_divisor=None,
                 pad_val=0,
                 seg_pad_val=255):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        # exactly one of size and size_divisor must be set
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, results):
        """Pad the image and record the padding metadata in ``results``."""
        if self.size is not None:
            padded_img = mmcv.impad(
                results['img'], shape=self.size, pad_val=self.pad_val)
        elif self.size_divisor is not None:
            padded_img = mmcv.impad_to_multiple(
                results['img'], self.size_divisor, pad_val=self.pad_val)
        results.update(
            img=padded_img,
            pad_shape=padded_img.shape,
            pad_fixed_size=self.size,
            pad_size_divisor=self.size_divisor)

    def _pad_seg(self, results):
        """Pad every map in 'seg_fields' to ``results['pad_shape']``."""
        for field in results.get('seg_fields', []):
            target_hw = results['pad_shape'][:2]
            results[field] = mmcv.impad(
                results[field], shape=target_hw, pad_val=self.seg_pad_val)

    def __call__(self, results):
        """Pad the image, then the semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_img(results)
        self._pad_seg(results)
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return cls_name + \
            f'(size={self.size}, size_divisor={self.size_divisor}, ' \
            f'pad_val={self.pad_val})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class Normalize(object):
    """Normalize the image.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        # Both statistics are stored as float32 arrays.
        self.mean, self.std = (
            np.array(stat, dtype=np.float32) for stat in (mean, std))
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize ``results['img']`` with the stored mean and std.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results, 'img_norm_cfg' key is added into
                result dict.
        """
        normalized = mmcv.imnormalize(results['img'], self.mean, self.std,
                                      self.to_rgb)
        results['img'] = normalized
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return cls_name + f'(mean={self.mean}, std={self.std}, to_rgb=' \
                          f'{self.to_rgb})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Rerange(object):
|
||||
"""Rerange the image pixel value.
|
||||
|
|
|
@ -1,162 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import glob
|
||||
import os
|
||||
from os.path import dirname, exists, isdir, join, relpath
|
||||
|
||||
from mmcv import Config
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
"""Find the predefined segmentor config directory."""
|
||||
try:
|
||||
# Assume we are running in the source mmsegmentation repo
|
||||
repo_dpath = dirname(dirname(__file__))
|
||||
except NameError:
|
||||
# For IPython development when this __file__ is not defined
|
||||
import mmseg
|
||||
repo_dpath = dirname(dirname(mmseg.__file__))
|
||||
config_dpath = join(repo_dpath, 'configs')
|
||||
if not exists(config_dpath):
|
||||
raise Exception('Cannot find config path')
|
||||
return config_dpath
|
||||
|
||||
|
||||
def test_config_build_segmentor():
    """Test that all segmentation models defined in the configs can be
    initialized.

    One config file is taken from each sub folder of the config directory;
    folders without any ``.py`` file are skipped.
    """
    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    config_fpaths = []
    # one config each sub folder
    for sub_folder in os.listdir(config_dpath):
        # os.listdir yields bare names, so the directory check has to be
        # made on the joined path rather than relative to the CWD.
        sub_dpath = join(config_dpath, sub_folder)
        if not isdir(sub_dpath):
            continue
        # sorted() makes the picked config deterministic across platforms
        candidates = sorted(glob.glob(join(sub_dpath, '*.py')))
        if candidates:
            config_fpaths.append(candidates[0])
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        config_mod = Config.fromfile(config_fpath)

        # bare attribute access: fails early if the config has no model
        config_mod.model
        print('Building segmentor, config_fpath = {!r}'.format(config_fpath))

        # Remove pretrained keys to allow for testing in an offline environment
        if 'pretrained' in config_mod.model:
            config_mod.model['pretrained'] = None

        print('building {}'.format(config_fname))
        segmentor = build_segmentor(config_mod.model)
        assert segmentor is not None

        head_config = config_mod.model['decode_head']
        _check_decode_head(head_config, segmentor.decode_head)
|
||||
|
||||
|
||||
def test_config_data_pipeline():
    """Test whether the data pipeline is valid and can process corner cases.

    Every non-``_base_`` config has its train and test pipelines built and
    run once on a random image (and, for training, a random seg map).

    CommandLine:
        xdoctest -m tests/test_config.py test_config_data_pipeline
    """
    import glob

    import numpy as np
    from mmcv import Config

    from mmseg.datasets.pipelines import Compose

    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    config_names = [
        relpath(path, config_dpath)
        for path in glob.glob(join(config_dpath, '**', '*.py'))
        if path.find('_base_') == -1
    ]
    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        print(
            'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
        config_mod = Config.fromfile(config_fpath)

        # remove loading pipeline: the first two train steps and the first
        # test step are dropped (presumably image/annotation loading).
        load_img_pipeline = config_mod.train_pipeline.pop(0)
        to_float32 = load_img_pipeline.get('to_float32', False)
        config_mod.train_pipeline.pop(0)
        config_mod.test_pipeline.pop(0)

        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)

        img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
        if to_float32:
            img = img.astype(np.float32)
        seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)

        # fields shared by the train and the test input
        img_meta = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape)

        train_input = dict(
            img_meta, gt_semantic_seg=seg, seg_fields=['gt_semantic_seg'])
        print('Test training data pipeline: \n{!r}'.format(train_pipeline))
        assert train_pipeline(train_input) is not None

        test_input = dict(img_meta)
        print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
        assert test_pipeline(test_input) is not None
|
||||
|
||||
|
||||
def _check_decode_head(decode_head_cfg, decode_head):
|
||||
if isinstance(decode_head_cfg, list):
|
||||
assert isinstance(decode_head, nn.ModuleList)
|
||||
assert len(decode_head_cfg) == len(decode_head)
|
||||
num_heads = len(decode_head)
|
||||
for i in range(num_heads):
|
||||
_check_decode_head(decode_head_cfg[i], decode_head[i])
|
||||
return
|
||||
# check consistency between head_config and roi_head
|
||||
assert decode_head_cfg['type'] == decode_head.__class__.__name__
|
||||
|
||||
assert decode_head_cfg['type'] == decode_head.__class__.__name__
|
||||
|
||||
in_channels = decode_head_cfg.in_channels
|
||||
input_transform = decode_head.input_transform
|
||||
assert input_transform in ['resize_concat', 'multiple_select', None]
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(decode_head.in_index, (list, tuple))
|
||||
assert len(in_channels) == len(decode_head.in_index)
|
||||
elif input_transform == 'resize_concat':
|
||||
assert sum(in_channels) == decode_head.in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert in_channels == decode_head.in_channels
|
||||
assert isinstance(decode_head.in_index, int)
|
||||
|
||||
if decode_head_cfg['type'] == 'PointHead':
|
||||
assert decode_head_cfg.channels+decode_head_cfg.num_classes == \
|
||||
decode_head.fc_seg.in_channels
|
||||
assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
|
||||
else:
|
||||
assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
|
||||
assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from mmengine.data import BaseDataElement
|
||||
|
||||
from mmseg.core import SegDataSample
|
||||
from mmseg.datasets.pipelines import PackSegInputs
|
||||
|
||||
|
||||
class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` transform."""

    def setUp(self):
        """Build the input dict and meta keys shared by every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        img_path = osp.join(
            osp.dirname(__file__), '../../data', 'color.jpg')
        rng = np.random.RandomState(0)
        # 'img' is drawn before 'gt_seg_map' so the seeded values stay stable
        self.results = dict(
            filename=img_path,
            ori_filename='color.jpg',
            ori_shape=(300, 400),
            pad_shape=(600, 800),
            img_shape=(600, 800),
            scale_factor=2.0,
            flip=False,
            flip_direction='horizontal',
            img_norm_cfg=None,
            img=rng.rand(300, 400),
            gt_seg_map=rng.rand(300, 400),
        )
        self.meta_keys = ('filename', 'ori_filename', 'ori_shape', 'img_shape',
                          'pad_shape', 'scale_factor', 'flip',
                          'flip_direction', 'img_norm_cfg')

    def test_transform(self):
        """The packed output exposes a ``SegDataSample`` with a seg map."""
        packer = PackSegInputs(meta_keys=self.meta_keys)
        packed = packer(copy.deepcopy(self.results))
        self.assertIn('data_sample', packed)
        sample = packed['data_sample']
        self.assertIsInstance(sample, SegDataSample)
        self.assertIsInstance(sample.gt_sem_seg, BaseDataElement)
        self.assertEqual(sample.ori_shape, sample.gt_sem_seg.shape)

    def test_repr(self):
        """``repr`` shows the class name and the configured meta keys."""
        packer = PackSegInputs(meta_keys=self.meta_keys)
        expected = f'PackSegInputs(meta_keys={self.meta_keys})'
        self.assertEqual(repr(packer), expected)
|
|
@ -0,0 +1,156 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.datasets.pipelines import LoadAnnotations
|
||||
|
||||
|
||||
class TestLoading(object):
    """Tests for ``LoadImageFromFile`` (mmcv) and ``LoadAnnotations``."""

    @classmethod
    def setup_class(cls):
        cls.data_prefix = osp.join(osp.dirname(__file__), '../data')

    @staticmethod
    def _load(**fields):
        """Run image loading then annotation loading on a fresh dict.

        ``fields`` are the entries of the input dict; an empty 'seg_fields'
        list is always added.
        """
        results = dict(seg_fields=[], **fields)
        results = LoadImageFromFile()(copy.deepcopy(results))
        return LoadAnnotations()(copy.deepcopy(results))

    def test_load_img(self):
        results = dict(img_path=osp.join(self.data_prefix, 'color.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['img_path'] == osp.join(self.data_prefix, 'color.jpg')
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8
        assert results['ori_shape'] == results['img'].shape[:2]
        assert repr(transform) == transform.__class__.__name__ + \
            "(to_float32=False, color_type='color'," + \
            " imdecode_backend='cv2', file_client_args={'backend': 'disk'})"

        # to_float32
        transform = LoadImageFromFile(to_float32=True)
        results = transform(copy.deepcopy(results))
        assert results['img'].dtype == np.float32

        # gray image
        results = dict(img_path=osp.join(self.data_prefix, 'gray.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8

        transform = LoadImageFromFile(color_type='unchanged')
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512)
        assert results['img'].dtype == np.uint8

    def test_load_seg(self):
        seg_path = osp.join(self.data_prefix, 'seg.png')
        results = dict(seg_map_path=seg_path, seg_fields=[])
        transform = LoadAnnotations()
        results = transform(copy.deepcopy(results))
        assert results['gt_seg_map'].shape == (288, 512)
        assert results['gt_seg_map'].dtype == np.uint8
        assert repr(transform) == transform.__class__.__name__ + \
            "(reduce_zero_label=False,imdecode_backend='pillow')" + \
            "file_client_args={'backend': 'disk'})"

        # reduce_zero_label
        transform = LoadAnnotations(reduce_zero_label=True)
        results = transform(copy.deepcopy(results))
        assert results['gt_seg_map'].shape == (288, 512)
        assert results['gt_seg_map'].dtype == np.uint8

    def test_load_seg_custom_classes(self):
        test_img = np.random.rand(10, 10)
        test_gt = np.zeros_like(test_img)
        test_gt[2:4, 2:4] = 1
        test_gt[2:4, 6:8] = 2
        test_gt[6:8, 2:4] = 3
        test_gt[6:8, 6:8] = 4

        # The context manager removes the directory even when an assertion
        # below fails, which a trailing cleanup() call would not.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = osp.join(tmp_dir, 'img.jpg')
            gt_path = osp.join(tmp_dir, 'gt.png')

            mmcv.imwrite(test_img, img_path)
            mmcv.imwrite(test_gt, gt_path)

            # test only train with label with id 3
            results = self._load(
                img_path=img_path,
                seg_map_path=gt_path,
                label_map={
                    0: 0,
                    1: 0,
                    2: 0,
                    3: 1,
                    4: 0
                })
            gt_array = results['gt_seg_map']

            true_mask = np.zeros_like(gt_array)
            true_mask[6:8, 2:4] = 1

            assert results['seg_fields'] == ['gt_seg_map']
            assert gt_array.shape == (10, 10)
            assert gt_array.dtype == np.uint8
            np.testing.assert_array_equal(gt_array, true_mask)

            # test only train with label with id 4 and 3
            results = self._load(
                img_path=osp.join(self.data_prefix, 'color.jpg'),
                seg_map_path=gt_path,
                label_map={
                    0: 0,
                    1: 0,
                    2: 0,
                    3: 2,
                    4: 1
                })
            gt_array = results['gt_seg_map']

            true_mask = np.zeros_like(gt_array)
            true_mask[6:8, 2:4] = 2
            true_mask[6:8, 6:8] = 1

            assert results['seg_fields'] == ['gt_seg_map']
            assert gt_array.shape == (10, 10)
            assert gt_array.dtype == np.uint8
            np.testing.assert_array_equal(gt_array, true_mask)

            # test no custom classes
            results = self._load(img_path=img_path, seg_map_path=gt_path)
            gt_array = results['gt_seg_map']

            assert results['seg_fields'] == ['gt_seg_map']
            assert gt_array.shape == (10, 10)
            assert gt_array.dtype == np.uint8
            np.testing.assert_array_equal(gt_array, test_gt)
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
def test_resize():
    """Exercise `Resize`, `RandomResize` and `RandomChoiceResize`.

    The transforms are built through the registry. Note that
    `RandomChoiceResize` takes the argument `scales`, whereas `Resize` and
    `RandomResize` take `scale`.
    """
    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    # 'pad_shape' and 'scale_factor' are initial values for default meta_keys
    results = dict(
        img=img,
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        scale_factor=1.0)

    def resized_shape(cfg):
        """Build ``cfg``, apply it to a shallow copy, return 'img_shape'."""
        return TRANSFORMS.build(cfg)(results.copy())['img_shape']

    # img_shape = results['img'].shape[:2] in ``MMCV resize`` function
    # so right now it is (750, 1333) rather than (750, 1333, 3)
    shape = resized_shape(
        dict(type='Resize', scale=(1333, 800), keep_ratio=True))
    assert shape == (750, 1333)

    # test keep_ratio=False
    shape = resized_shape(
        dict(
            type='RandomResize',
            scale=(1280, 800),
            ratio_range=(1.0, 1.0),
            resize_cfg=dict(type='Resize', keep_ratio=False)))
    assert shape == (800, 1280)

    # test `RandomResize` with two scales, which in older mmsegmentation
    # `Resize` is multiscale_mode='range'
    shape = resized_shape(
        dict(type='RandomResize', scale=[(1333, 400), (1333, 1200)]))
    assert max(shape[:2]) <= 1333
    assert min(shape[:2]) >= 400
    assert min(shape[:2]) <= 1200

    # test RandomChoiceResize, which in older mmsegmentation
    # `Resize` is multiscale_mode='value'
    shape = resized_shape(
        dict(
            type='RandomChoiceResize',
            scales=[(1333, 800), (1333, 400)],
            resize_cfg=dict(type='Resize', keep_ratio=True)))
    assert shape in [(750, 1333), (400, 711)]

    shape = resized_shape(
        dict(type='Resize', scale_factor=(0.9, 1.1), keep_ratio=True))
    assert max(shape[:2]) <= 1333 * 1.1

    # test scale=None and scale_factor is tuple.
    # img shape: (288, 512, 3)
    shape = resized_shape(
        dict(
            type='Resize',
            scale=None,
            scale_factor=(0.5, 2.0),
            keep_ratio=True))
    assert int(288 * 0.5) <= shape[0] <= 288 * 2.0
    assert int(512 * 0.5) <= shape[1] <= 512 * 2.0

    # test minimum resized image shape is 640
    shape = resized_shape(
        dict(type='Resize', scale=(2560, 640), keep_ratio=True))
    assert shape == (640, 1138)

    # test minimum resized image shape is 640 when img_scale=(512, 640)
    # where should define `scale_factor` in MMCV new ``Resize`` function.
    min_size_ratio = max(640 / img.shape[0], 640 / img.shape[1])
    shape = resized_shape(
        dict(type='Resize', scale_factor=min_size_ratio, keep_ratio=True))
    assert shape == (640, 1138)

    # test h > w
    img = np.random.randn(512, 288, 3)
    results.update(
        img=img,
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        scale_factor=1.0)
    min_size_ratio = max(640 / img.shape[0], 640 / img.shape[1])
    shape = resized_shape(
        dict(
            type='Resize',
            scale=(2560, 640),
            scale_factor=min_size_ratio,
            keep_ratio=True))
    assert shape == (1138, 640)
|
||||
|
||||
|
||||
def test_flip():
    """`RandomFlip` validates its arguments, and flipping twice is a no-op."""
    # test assertion for invalid prob
    with pytest.raises(AssertionError):
        TRANSFORMS.build(dict(type='RandomFlip', prob=1.5))

    # test assertion for invalid direction
    with pytest.raises(AssertionError):
        TRANSFORMS.build(
            dict(type='RandomFlip', prob=1.0, direction='horizonta'))

    data_dir = osp.join(osp.dirname(__file__), '../data')
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    original_seg = copy.deepcopy(seg)
    del data_dir  # paths above are spelled out to stay byte-identical

    # 'pad_shape' and 'scale_factor' are initial values for default meta_keys
    results = dict(
        img=img,
        gt_semantic_seg=seg,
        seg_fields=['gt_semantic_seg'],
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        scale_factor=1.0)

    # apply two freshly built flips in a row and compare with the originals
    flip_cfg = dict(type='RandomFlip', prob=1.0)
    for _ in range(2):
        results = TRANSFORMS.build(flip_cfg)(results)

    assert np.equal(original_img, results['img']).all()
    assert np.equal(original_seg, results['gt_semantic_seg']).all()
|
||||
|
||||
|
||||
def test_pad():
    """`Pad` needs a size or divisor and leaves an aligned image unchanged."""
    # test assertion if both size_divisor and size is None
    with pytest.raises(AssertionError):
        TRANSFORMS.build(dict(type='Pad'))

    pad_module = TRANSFORMS.build(dict(type='Pad', size_divisor=32))

    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    # 'pad_shape' and 'scale_factor' are initial values for default meta_keys
    results = dict(
        img=img,
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        scale_factor=1.0)

    results = pad_module(results)

    # original img already divisible by 32
    assert np.equal(results['img'], original_img).all()
    padded_h, padded_w = results['img'].shape[:2]
    assert padded_h % 32 == 0
    assert padded_w % 32 == 0
|
||||
|
||||
|
||||
def test_normalize():
    """`Normalize` output matches a manual channel-reversed normalization."""
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    normalize = TRANSFORMS.build(dict(type='Normalize', **img_norm_cfg))

    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    # 'pad_shape' and 'scale_factor' are initial values for default meta_keys
    results = dict(
        img=img,
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        scale_factor=1.0)

    results = normalize(results)

    # reference: reverse the channel axis, then standardize
    mean = np.array(img_norm_cfg['mean'])
    std = np.array(img_norm_cfg['std'])
    expected = (original_img[..., ::-1] - mean) / std
    assert np.allclose(results['img'], expected)
|
|
@ -0,0 +1,151 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import pytest
|
||||
|
||||
from mmseg.datasets.pipelines import * # noqa
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
def test_multi_scale_flip_aug():
    """Exercise ``MultiScaleFlipAug`` build-time checks and its outputs.

    The first part expects ``TRANSFORMS.build`` to raise ``AssertionError``
    for three argument combinations. The second part runs the built
    transform on ``../data/color.jpg`` and compares the ``scale`` and
    ``flip`` attributes of the returned ``data_sample`` entries with the
    expected values for several ``scales`` / ``scale_factor`` /
    ``allow_flip`` settings.
    """
    # Argument combinations that are expected to be rejected at build time:
    #   scales=None, scale_factor=1 (not float)
    #   scales=None, scale_factor=None
    #   scales=(512, 512), scale_factor=1 (not float)
    rejected_args = [(None, 1), (None, None), ((512, 512), 1)]
    for bad_scales, bad_scale_factor in rejected_args:
        bad_cfg = dict(
            type='MultiScaleFlipAug',
            scales=bad_scales,
            scale_factor=bad_scale_factor,
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
                 'scale_factor', 'scale', 'flip')

    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    # 'pad_shape' and 'scale_factor' are initial values for the default
    # meta_keys.
    results = dict(
        img=img,
        ori_shape=img.shape,
        ori_height=img.shape[0],
        ori_width=img.shape[1],
        pad_shape=img.shape,
        scale_factor=1.0)

    def run_tta(allow_flip, **size_kwargs):
        """Build the TTA transform and return its packed data samples.

        ``size_kwargs`` carries either ``scales`` or ``scale_factor``.
        """
        tta_cfg = {
            'type': 'MultiScaleFlipAug',
            **size_kwargs,
            'allow_flip': allow_flip,
            'resize_cfg': dict(type='Resize', keep_ratio=False),
            'transforms':
            [dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
        }
        tta_module = TRANSFORMS.build(tta_cfg)
        # Each run receives its own shallow copy of the input dict.
        return tta_module(results.copy())['data_sample']

    # Several explicit scales, no flipping.
    samples = run_tta(False, scales=[(256, 256), (512, 512), (1024, 1024)])
    assert [sample.scale for sample in samples] == [
        (256, 256), (512, 512), (1024, 1024)
    ]
    assert [sample.flip for sample in samples] == [False, False, False]

    # Several explicit scales, flipping allowed.
    samples = run_tta(True, scales=[(256, 256), (512, 512), (1024, 1024)])
    assert [sample.scale for sample in samples] == [
        (256, 256), (256, 256), (512, 512), (512, 512), (1024, 1024),
        (1024, 1024)
    ]
    assert [sample.flip for sample in samples
            ] == [False, True, False, True, False, True]

    # One explicit scale, no flipping: only the first sample is inspected.
    samples = run_tta(False, scales=[(512, 512)])
    assert [samples[0].scale] == [(512, 512)]
    assert [samples[0].flip] == [False]

    # One explicit scale, flipping allowed.
    samples = run_tta(True, scales=[(512, 512)])
    assert [sample.scale for sample in samples] == [(512, 512), (512, 512)]
    assert [sample.flip for sample in samples] == [False, True]

    # Scale factors instead of explicit scales, no flipping.
    samples = run_tta(False, scale_factor=[0.5, 1.0, 2.0])
    assert [sample.scale for sample in samples] == [
        (256, 144), (512, 288), (1024, 576)
    ]
    assert [sample.flip for sample in samples] == [False, False, False]

    # Scale factors instead of explicit scales, flipping allowed.
    samples = run_tta(True, scale_factor=[0.5, 1.0, 2.0])
    assert [sample.scale for sample in samples] == [
        (256, 144), (256, 144), (512, 288), (512, 288), (1024, 576),
        (1024, 576)
    ]
    assert [sample.flip for sample in samples
            ] == [False, True, False, True, False, True]
|
Loading…
Reference in New Issue