mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Add base transform interface (#1538)
* Support deepcopy for Config (#1658) * Support deepcopy for Config * Iterate the `__dict__` of Config directly. * Use __new__ to avoid unnecessary initialization. * Improve according to comments * [Feature] Add spconv ops from mmdet3d (#1581) * add ops (spconv) of mmdet3d * fix typo * refactor code * resolve comments in #1452 * fix compile error * fix bugs * fix bug * transform from 'types.h' to 'extension.h' * fix bug * transform from 'types.h' to 'extension.h' in parrots * add extension.h in pybind.cpp * add unittest * Recover code * (1) Remove prettyprint.h (2) Switch `T` to `scalar_t` (3) Remove useless lines (4) Refine example in docstring of sparse_modules.py * (1) rename from `cu.h` to `cuh` (2) remove useless files (3) move cpu files to `pytorch/cpu` * reorganize files * Add docstring for sparse_functional.py * use dispatcher * remove template * use dispatch in cuda ops * resolve Segmentation fault * remove useless files * fix lint * fix lint * fix lint * fix unittest in test_build_layers.py * add tensorview into include_dirs when compiling * recover all deleted files * fix lint and comments * recover setup.py * replace tv::GPU as tv::TorchGPU & support device guard * fix lint Co-authored-by: hdc <hudingchang.vendor@sensetime.com> Co-authored-by: grimoire <yaoqian@sensetime.com> * Imporve the docstring of imfrombytes and fix a deprecation-warning (#1731) * [Refactor] Refactor the interface for RoIAlignRotated (#1662) * fix interface for RoIAlignRotated * Add a unit test for RoIAlignRotated * Make a unit test for RoIAlignRotated concise * fix interface for RoIAlignRotated * Refactor ext_module.nms_rotated * Lint cpp files * add transforms * add invoking time check for cacheable methods * fix lint * add unittest * fix bug in non-strict input mapping * fix ci * fix ci * fix compatibility with python<3.9 * fix typing compatibility * fix import * fix typing * add alternative for nullcontext * fix import * fix import * add docstrings * add docstrings * fix callable check * resolve comments * fix lint * enrich unittest cases * fix lint * fix unittest Co-authored-by: Ma Zerun <mzr1996@163.com> Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com> Co-authored-by: hdc <hudingchang.vendor@sensetime.com> Co-authored-by: grimoire <yaoqian@sensetime.com> Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com> Co-authored-by: Hakjin Lee <nijkah@gmail.com>
This commit is contained in:
parent
8b47579e7c
commit
d00b0cec74
@ -47,3 +47,8 @@ ops
|
||||
------
|
||||
.. automodule:: mmcv.ops
|
||||
:members:
|
||||
|
||||
transform
|
||||
---------
|
||||
.. automodule:: mmcv.transform
|
||||
:members:
|
||||
|
@ -47,3 +47,8 @@ ops
|
||||
------
|
||||
.. automodule:: mmcv.ops
|
||||
:members:
|
||||
|
||||
transform
|
||||
---------
|
||||
.. automodule:: mmcv.transform
|
||||
:members:
|
||||
|
@ -3,6 +3,7 @@
|
||||
from .arraymisc import *
|
||||
from .fileio import *
|
||||
from .image import *
|
||||
from .transform import *
|
||||
from .utils import *
|
||||
from .version import *
|
||||
from .video import *
|
||||
|
5
mmcv/transform/__init__.py
Normal file
5
mmcv/transform/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import TRANSFORMS
|
||||
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
|
||||
|
||||
__all__ = ['TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap']
|
27
mmcv/transform/base.py
Normal file
27
mmcv/transform/base.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BaseTransform(metaclass=ABCMeta):
|
||||
|
||||
def __call__(self, results: Dict) -> Dict:
|
||||
|
||||
return self.transform(results)
|
||||
|
||||
@abstractmethod
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""The transform function. All subclass of BaseTransform should
|
||||
override this method.
|
||||
|
||||
This function takes the result dict as the input, and can add new
|
||||
items to the dict or modify existing items in the dict. And the result
|
||||
dict will be returned in the end, which allows to concate multiple
|
||||
transforms into a pipeline.
|
||||
|
||||
Args:
|
||||
results (dict): The result dict.
|
||||
|
||||
Returns:
|
||||
dict: The result dict.
|
||||
"""
|
4
mmcv/transform/builder.py
Normal file
4
mmcv/transform/builder.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from ..utils.registry import Registry
|
||||
|
||||
TRANSFORMS = Registry('transform')
|
162
mmcv/transform/utils.py
Normal file
162
mmcv/transform/utils.py
Normal file
@ -0,0 +1,162 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Union
|
||||
|
||||
from .base import BaseTransform
|
||||
|
||||
|
||||
class cacheable_method:
|
||||
"""Decorator that marks a method of a transform class as a cacheable
|
||||
method.
|
||||
|
||||
This decorator is usually used together with the context-manager
|
||||
:func`:cache_random_params`. In this context, a cacheable method will
|
||||
cache its return value(s) at the first time of being invoked, and always
|
||||
return the cached values when being invoked again.
|
||||
|
||||
.. note::
|
||||
Only a instance method can be decorated as a cacheable_method.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
|
||||
# Check `func` is to be bound as an instance method
|
||||
if not inspect.isfunction(func):
|
||||
raise TypeError('Unsupport callable to decorate with'
|
||||
'@cacheable_method.')
|
||||
func_args = inspect.getfullargspec(func).args
|
||||
if len(func_args) == 0 or func_args[0] != 'self':
|
||||
raise TypeError(
|
||||
'@cacheable_method should only be used to decorate '
|
||||
'instance methods (the first argument is `self`).')
|
||||
|
||||
functools.update_wrapper(self, func)
|
||||
self.func = func
|
||||
self.instance_ref = None
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
# Maintain a record of decorated methods in the class
|
||||
if not hasattr(owner, '_cacheable_methods'):
|
||||
setattr(owner, '_cacheable_methods', [])
|
||||
owner._cacheable_methods.append(self.__name__)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Get the transform instance whose method is decorated
|
||||
# by cacheable_method
|
||||
instance = self.instance_ref()
|
||||
name = self.__name__
|
||||
|
||||
# Check the flag `self._cache_enabled`, which should be
|
||||
# set by the contextmanagers like `cache_random_parameters`
|
||||
cache_enabled = getattr(instance, '_cache_enabled', False)
|
||||
|
||||
if cache_enabled:
|
||||
# Initialize the cache of the transform instances. The flag
|
||||
# `cache_enabled` is set by contextmanagers like
|
||||
# `cache_random_params`.
|
||||
if not hasattr(instance, '_cache'):
|
||||
setattr(instance, '_cache', {})
|
||||
|
||||
if name not in instance._cache:
|
||||
instance._cache[name] = self.func(instance, *args, **kwargs)
|
||||
# Return the cached value
|
||||
return instance._cache[name]
|
||||
else:
|
||||
# Clear cache
|
||||
if hasattr(instance, '_cache'):
|
||||
del instance._cache
|
||||
# Return function output
|
||||
return self.func(instance, *args, **kwargs)
|
||||
|
||||
def __get__(self, obj, cls):
|
||||
self.instance_ref = weakref.ref(obj)
|
||||
return self
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
|
||||
"""Context-manager that enables the cache of cacheable methods in
|
||||
transforms.
|
||||
|
||||
In this mode, cacheable methods will cache their return values on the
|
||||
first invoking, and always return the cached value afterward. This allow
|
||||
to apply random transforms in a deterministic way. For example, apply same
|
||||
transforms on multiple examples. See `cacheable_method` for more
|
||||
information.
|
||||
|
||||
Args:
|
||||
transforms (BaseTransform|list[BaseTransform]): The transforms to
|
||||
enable cache.
|
||||
"""
|
||||
|
||||
# key2method stores the original methods that are replaced by the wrapped
|
||||
# ones. These methods will be restituted when exiting the context.
|
||||
key2method = dict()
|
||||
|
||||
# key2counter stores the usage number of each cacheable_method. This is
|
||||
# used to check that any cacheable_method is invoked once during processing
|
||||
# on data sample.
|
||||
key2counter = defaultdict(int)
|
||||
|
||||
def _add_counter(obj, method_name):
|
||||
method = getattr(obj, method_name)
|
||||
key = f'{id(obj)}.{method_name}'
|
||||
key2method[key] = method
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapped(*args, **kwargs):
|
||||
key2counter[key] += 1
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
def _start_cache(t: BaseTransform):
|
||||
# Set cache enabled flag
|
||||
setattr(t, '_cache_enabled', True)
|
||||
|
||||
# Store the original method and init the counter
|
||||
if hasattr(t, '_cacheable_methods'):
|
||||
setattr(t, 'transform', _add_counter(t, 'transform'))
|
||||
for name in t._cacheable_methods:
|
||||
setattr(t, name, _add_counter(t, name))
|
||||
|
||||
def _end_cache(t: BaseTransform):
|
||||
# Remove cache enabled flag
|
||||
del t._cache_enabled
|
||||
if hasattr(t, '_cache'):
|
||||
del t._cache
|
||||
|
||||
# Restore the original method
|
||||
if hasattr(t, '_cacheable_methods'):
|
||||
key_transform = f'{id(t)}.transform'
|
||||
for name in t._cacheable_methods:
|
||||
key = f'{id(t)}.{name}'
|
||||
if key2counter[key] != key2counter[key_transform]:
|
||||
raise RuntimeError(
|
||||
'The cacheable method should be called once and only'
|
||||
f'once during processing one data sample. {t} got'
|
||||
f'unmatched number of {key2counter[key]} ({name}) vs'
|
||||
f'{key2counter[key_transform]} (data samples)')
|
||||
setattr(t, name, key2method[key])
|
||||
setattr(t, 'transform', key2method[key_transform])
|
||||
|
||||
def _apply(t: Union[BaseTransform, Iterable],
|
||||
func: Callable[[BaseTransform], None]):
|
||||
if isinstance(t, BaseTransform):
|
||||
if hasattr(t, '_cacheable_methods'):
|
||||
func(t)
|
||||
if isinstance(t, Iterable):
|
||||
for _t in t:
|
||||
_apply(_t, func)
|
||||
|
||||
try:
|
||||
_apply(transforms, _start_cache)
|
||||
yield
|
||||
finally:
|
||||
_apply(transforms, _end_cache)
|
457
mmcv/transform/wrappers.py
Normal file
457
mmcv/transform/wrappers.py
Normal file
@ -0,0 +1,457 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mmcv
|
||||
from .base import BaseTransform
|
||||
from .builder import TRANSFORMS
|
||||
from .utils import cache_random_params
|
||||
|
||||
# Indicator for required but missing keys in results
|
||||
NotInResults = object()
|
||||
|
||||
# Import nullcontext if python>=3.7, otherwise use a simple alternative
|
||||
# implementation.
|
||||
try:
|
||||
from contextlib import nullcontext
|
||||
except ImportError:
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def nullcontext(resource=None):
|
||||
try:
|
||||
yield resource
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
class Compose(BaseTransform):
|
||||
"""Compose multiple transforms sequentially.
|
||||
|
||||
Args:
|
||||
transforms (list[dict | callable]): Sequence of transform object or
|
||||
config dict to be composed.
|
||||
|
||||
Examples:
|
||||
>>> pipeline = [
|
||||
>>> dict(type='Compose',
|
||||
>>> transforms=[
|
||||
>>> dict(type='LoadImageFromFile'),
|
||||
>>> dict(type='Normalize')
|
||||
>>> ]
|
||||
>>> )
|
||||
>>> ]
|
||||
"""
|
||||
|
||||
def __init__(self, transforms: List[Union[Dict, Callable[[Dict], Dict]]]):
|
||||
assert isinstance(transforms, Sequence)
|
||||
self.transforms = []
|
||||
for transform in transforms:
|
||||
if isinstance(transform, dict):
|
||||
transform = TRANSFORMS.build(transform)
|
||||
self.transforms.append(transform)
|
||||
elif callable(transform):
|
||||
self.transforms.append(transform)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict, but got'
|
||||
f' {type(transform)}')
|
||||
|
||||
def __iter__(self):
|
||||
"""Allow easy iteration over the transform sequence."""
|
||||
return iter(self.transforms)
|
||||
|
||||
def transform(self, results: Dict) -> Optional[Dict]:
|
||||
"""Call function to apply transforms sequentially.
|
||||
|
||||
Args:
|
||||
results (dict): A result dict contains the results to transform.
|
||||
|
||||
Returns:
|
||||
dict or None: Transformed results.
|
||||
"""
|
||||
for t in self.transforms:
|
||||
results = t(results)
|
||||
if results is None:
|
||||
return None
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
"""Compute the string representation."""
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += f'\n {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Remap(BaseTransform):
|
||||
"""A transform wrapper to remap and reorganize the input/output of the
|
||||
wrapped transforms (or sub-pipeline).
|
||||
|
||||
Args:
|
||||
transforms (list[dict | callable]): Sequence of transform object or
|
||||
config dict to be wrapped.
|
||||
input_mapping (dict): A dict that defines the input key mapping.
|
||||
The keys corresponds to the inner key (i.e., kwargs of the
|
||||
`transform` method), and should be string type. The values
|
||||
corresponds to the outer keys (i.e., the keys of the
|
||||
data/results), and should have a type of string, list or dict.
|
||||
None means not applying input mapping. Default: None.
|
||||
output_mapping (dict): A dict that defines the output key mapping.
|
||||
The keys and values have the same meanings and rules as in the
|
||||
`input_mapping`. Default: None.
|
||||
inplace (bool): If True, an inverse of the input_mapping will be used
|
||||
as the output_mapping. Note that if inplace is set True,
|
||||
output_mapping should be None and strict should be True.
|
||||
Default: False.
|
||||
strict (bool): If True, the outer keys in the input_mapping must exist
|
||||
in the input data, or an exception will be raised. If False,
|
||||
the missing keys will be assigned a special value `NotInResults`
|
||||
during input remapping. Default: True.
|
||||
|
||||
Examples:
|
||||
>>> # Example 1: Remap 'gt_img' to 'img'
|
||||
>>> pipeline = [
|
||||
>>> # Use Remap to convert outer (original) field name 'gt_img'
|
||||
>>> # to inner (used by inner transforms) filed name 'img'
|
||||
>>> dict(type='Remap',
|
||||
>>> input_mapping=dict(img='gt_img'),
|
||||
>>> # inplace=True means output key mapping is the revert of
|
||||
>>> # the input key mapping, e.g. inner 'img' will be mapped
|
||||
>>> # back to outer 'gt_img'
|
||||
>>> inplace=True,
|
||||
>>> transforms=[
|
||||
>>> # In all transforms' implementation just use 'img'
|
||||
>>> # as a standard field name
|
||||
>>> dict(type='Crop', crop_size=(384, 384)),
|
||||
>>> dict(type='Normalize'),
|
||||
>>> ])
|
||||
>>> ]
|
||||
>>> # Example 2: Collect and structure multiple items
|
||||
>>> pipeline = [
|
||||
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
|
||||
>>> # and 'img_tar', whose values are outer fields 'img1' and
|
||||
>>> # 'img2' respectively.
|
||||
>>> dict(type='Remap',
|
||||
>>> dict(
|
||||
>>> type='Remap',
|
||||
>>> input_mapping=dict(
|
||||
>>> imgs=dict(
|
||||
>>> img_src='img1',
|
||||
>>> img_tar='img2')),
|
||||
>>> transforms=...)
|
||||
>>> ]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
transforms: List[Union[Dict, Callable[[Dict], Dict]]],
|
||||
input_mapping: Optional[Dict] = None,
|
||||
output_mapping: Optional[Dict] = None,
|
||||
inplace: bool = False,
|
||||
strict: bool = True):
|
||||
|
||||
self.inplace = inplace
|
||||
self.strict = strict
|
||||
self.input_mapping = input_mapping
|
||||
|
||||
if self.inplace:
|
||||
if not self.strict:
|
||||
raise ValueError('Remap: `strict` must be set True if'
|
||||
'`inplace` is set True.')
|
||||
|
||||
if output_mapping is not None:
|
||||
raise ValueError('Remap: `output_mapping` must be None if'
|
||||
'`inplace` is set True.')
|
||||
self.output_mapping = input_mapping
|
||||
else:
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
self.transforms = Compose(transforms)
|
||||
|
||||
def __iter__(self):
|
||||
"""Allow easy iteration over the transform sequence."""
|
||||
return iter(self.transforms)
|
||||
|
||||
def remap_input(self, data: Dict, input_mapping: Dict) -> Dict[str, Any]:
|
||||
"""Remap inputs for the wrapped transforms by gathering and renaming
|
||||
data items according to the input_mapping.
|
||||
|
||||
Args:
|
||||
data (dict): The original input data
|
||||
input_mapping (dict): The input key mapping. See the document of
|
||||
`mmcv.transforms.wrappers.Remap` for details.
|
||||
|
||||
Returns:
|
||||
dict: The input data with remapped keys. This will be the actual
|
||||
input of the wrapped pipeline.
|
||||
"""
|
||||
|
||||
def _remap(data, m):
|
||||
if isinstance(m, dict):
|
||||
# m is a dict {inner_key:outer_key, ...}
|
||||
return {k_in: _remap(data, k_out) for k_in, k_out in m.items()}
|
||||
if isinstance(m, (tuple, list)):
|
||||
# m is a list or tuple [outer_key1, outer_key2, ...]
|
||||
# This is the case when we collect items from the original
|
||||
# data to form a list or tuple to feed to the wrapped
|
||||
# transforms.
|
||||
return m.__class__(_remap(data, e) for e in m)
|
||||
|
||||
# m is an outer_key
|
||||
if self.strict:
|
||||
return data.get(m)
|
||||
else:
|
||||
return data.get(m, NotInResults)
|
||||
|
||||
collected = _remap(data, input_mapping)
|
||||
collected = {
|
||||
k: v
|
||||
for k, v in collected.items() if v is not NotInResults
|
||||
}
|
||||
|
||||
# Retain unmapped items
|
||||
inputs = data.copy()
|
||||
inputs.update(collected)
|
||||
|
||||
return inputs
|
||||
|
||||
def remap_output(self, data: Dict, output_mapping: Dict) -> Dict[str, Any]:
|
||||
"""Remap outputs from the wrapped transforms by gathering and renaming
|
||||
data items according to the output_mapping.
|
||||
|
||||
Args:
|
||||
data (dict): The output of the wrapped pipeline.
|
||||
output_mapping (dict): The output key mapping. See the document of
|
||||
`mmcv.transforms.wrappers.Remap` for details.
|
||||
|
||||
Returns:
|
||||
dict: The output with remapped keys.
|
||||
"""
|
||||
|
||||
def _remap(data, m):
|
||||
if isinstance(m, dict):
|
||||
assert isinstance(data, dict)
|
||||
results = {}
|
||||
for k_in, k_out in m.items():
|
||||
assert k_in in data
|
||||
results.update(_remap(data[k_in], k_out))
|
||||
return results
|
||||
if isinstance(m, (list, tuple)):
|
||||
assert isinstance(data, (list, tuple))
|
||||
assert len(data) == len(m)
|
||||
results = {}
|
||||
for m_i, d_i in zip(m, data):
|
||||
results.update(_remap(d_i, m_i))
|
||||
return results
|
||||
|
||||
return {m: data}
|
||||
|
||||
# Note that unmapped items are not retained, which is different from
|
||||
# the behavior in remap_input. This is to avoid original data items
|
||||
# being overwritten by intermediate namesakes
|
||||
return _remap(data, output_mapping)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
|
||||
inputs = self.remap_input(results, self.input_mapping)
|
||||
outputs = self.transforms(inputs)
|
||||
|
||||
if self.output_mapping:
|
||||
outputs = self.remap_output(outputs, self.output_mapping)
|
||||
|
||||
results.update(outputs)
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ApplyToMultiple(Remap):
|
||||
"""A transform wrapper to apply the wrapped transforms to multiple data
|
||||
items. For example, apply Resize to multiple images.
|
||||
|
||||
Args:
|
||||
transforms (list[dict | callable]): Sequence of transform object or
|
||||
config dict to be wrapped.
|
||||
input_mapping (dict): A dict that defines the input key mapping.
|
||||
Note that to apply the transforms to multiple data items, the
|
||||
outer keys of the target items should be remapped as a list with
|
||||
the standard inner key (The key required by the wrapped transform).
|
||||
See the following example and the document of
|
||||
`mmcv.transforms.wrappers.Remap` for details.
|
||||
output_mapping (dict): A dict that defines the output key mapping.
|
||||
The keys and values have the same meanings and rules as in the
|
||||
`input_mapping`. Default: None.
|
||||
inplace (bool): If True, an inverse of the input_mapping will be used
|
||||
as the output_mapping. Note that if inplace is set True,
|
||||
output_mapping should be None and strict should be True.
|
||||
Default: False.
|
||||
strict (bool): If True, the outer keys in the input_mapping must exist
|
||||
in the input data, or an exception will be raised. If False,
|
||||
the missing keys will be assigned a special value `NotInResults`
|
||||
during input remapping. Default: True.
|
||||
share_random_params (bool): If True, the random transform
|
||||
(e.g., RandomFlip) will be conducted in a deterministic way and
|
||||
have the same behavior on all data items. For example, to randomly
|
||||
flip either both input image and ground-truth image, or none.
|
||||
Default: False.
|
||||
|
||||
.. note::
|
||||
To apply the transforms to each elements of a list or tuple, instead
|
||||
of separating data items, you can remap the outer key of the target
|
||||
sequence to the standard inner key. See example 2.
|
||||
example.
|
||||
|
||||
Examples:
|
||||
>>> # Example 1:
|
||||
>>> pipeline = [
|
||||
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
|
||||
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
|
||||
>>> # ApplyToMultiple maps multiple outer fields to standard the
|
||||
>>> # inner field and process them with wrapped transforms
|
||||
>>> # respectively
|
||||
>>> dict(type='ApplyToMultiple',
|
||||
>>> # case 1: from multiple outer fields
|
||||
>>> input_mapping=dict(img=['lq', 'gt']),
|
||||
>>> inplace=True,
|
||||
>>> # share_random_param=True means using identical random
|
||||
>>> # parameters in every processing
|
||||
>>> share_random_param=True,
|
||||
>>> transforms=[
|
||||
>>> dict(type='Crop', crop_size=(384, 384)),
|
||||
>>> dict(type='Normalize'),
|
||||
>>> ])
|
||||
>>> ]
|
||||
>>> # Example 2:
|
||||
>>> pipeline = [
|
||||
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
|
||||
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
|
||||
>>> # ApplyToMultiple maps multiple outer fields to standard the
|
||||
>>> # inner field and process them with wrapped transforms
|
||||
>>> # respectively
|
||||
>>> dict(type='ApplyToMultiple',
|
||||
>>> # case 2: from one outer field that contains multiple
|
||||
>>> # data elements (e.g. a list)
|
||||
>>> # input_mapping=dict(img='images'),
|
||||
>>> inplace=True,
|
||||
>>> share_random_param=True,
|
||||
>>> transforms=[
|
||||
>>> dict(type='Crop', crop_size=(384, 384)),
|
||||
>>> dict(type='Normalize'),
|
||||
>>> ])
|
||||
>>> ]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
transforms: List[Union[Dict, Callable[[Dict], Dict]]],
|
||||
input_mapping: Optional[Dict] = None,
|
||||
output_mapping: Optional[Dict] = None,
|
||||
inplace: bool = False,
|
||||
strict: bool = True,
|
||||
share_random_params: bool = False):
|
||||
super().__init__(transforms, input_mapping, output_mapping, inplace,
|
||||
strict)
|
||||
|
||||
self.share_random_params = share_random_params
|
||||
|
||||
def scatter_sequence(self, data: Dict) -> List[Dict]:
|
||||
# infer split number from input
|
||||
seq_len = None
|
||||
key_rep = None
|
||||
for key in self.input_mapping:
|
||||
|
||||
assert isinstance(data[key], Sequence)
|
||||
if seq_len is not None:
|
||||
if len(data[key]) != seq_len:
|
||||
raise ValueError('Got inconsistent sequence length: '
|
||||
f'{seq_len} ({key_rep}) vs. '
|
||||
f'{len(data[key])} ({key})')
|
||||
else:
|
||||
seq_len = len(data[key])
|
||||
key_rep = key
|
||||
|
||||
scatters = []
|
||||
for i in range(seq_len):
|
||||
scatter = data.copy()
|
||||
for key in self.input_mapping:
|
||||
scatter[key] = data[key][i]
|
||||
scatters.append(scatter)
|
||||
return scatters
|
||||
|
||||
def transform(self, results: Dict):
|
||||
# Apply input remapping
|
||||
inputs = self.remap_input(results, self.input_mapping)
|
||||
|
||||
# Scatter sequential inputs into a list
|
||||
inputs = self.scatter_sequence(inputs)
|
||||
|
||||
# Control random parameter sharing with a context manager
|
||||
if self.share_random_params:
|
||||
# The context manager :func`:cache_random_params` will let
|
||||
# cacheable method of the transforms cache their outputs. Thus
|
||||
# the random parameters will only generated once and shared
|
||||
# by all data items.
|
||||
ctx = cache_random_params
|
||||
else:
|
||||
ctx = nullcontext
|
||||
|
||||
with ctx(self.transforms):
|
||||
outputs = [self.transforms(_input) for _input in inputs]
|
||||
|
||||
# Collate output scatters (list of dict to dict of list)
|
||||
outputs = {
|
||||
key: [_output[key] for _output in outputs]
|
||||
for key in outputs[0]
|
||||
}
|
||||
|
||||
# Apply output remapping
|
||||
if self.output_mapping:
|
||||
outputs = self.remap_output(outputs, self.output_mapping)
|
||||
|
||||
results.update(outputs)
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomChoice(BaseTransform):
|
||||
"""Process data with a randomly chosen pipeline from given candidates.
|
||||
|
||||
Args:
|
||||
pipelines (list[list]): A list of pipeline candidates, each is a
|
||||
sequence of transforms.
|
||||
pipeline_probs (list[float], optional): The probabilities associated
|
||||
with each pipeline. The length should be equal to the pipeline
|
||||
number and the sum should be 1. If not given, a uniform
|
||||
distribution will be assumed.
|
||||
|
||||
Examples:
|
||||
>>> # config
|
||||
>>> pipeline = [
|
||||
>>> dict(type='RandomChoice',
|
||||
>>> pipelines=[
|
||||
>>> [dict(type='RandomHorizontalFlip')], # subpipeline 1
|
||||
>>> [dict(type='RandomRotate')], # subpipeline 2
|
||||
>>> ]
|
||||
>>> )
|
||||
>>> ]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pipelines: List[List[Union[Dict, Callable[[Dict], Dict]]]],
|
||||
pipeline_probs: Optional[List[float]] = None):
|
||||
|
||||
if pipeline_probs is not None:
|
||||
assert mmcv.is_seq_of(pipeline_probs, float)
|
||||
assert len(pipelines) == len(pipeline_probs), \
|
||||
'`pipelines` and `pipeline_probs` must have same lengths. ' \
|
||||
f'Got {len(pipelines)} vs {len(pipeline_probs)}.'
|
||||
assert sum(pipeline_probs) == 1
|
||||
|
||||
self.pipeline_probs = pipeline_probs
|
||||
self.pipelines = [Compose(transforms) for transforms in pipelines]
|
||||
|
||||
def transform(self, results):
|
||||
pipeline = np.random.choice(self.pipelines, p=self.pipeline_probs)
|
||||
return pipeline(results)
|
370
tests/test_transform/test_transform_wrapper.py
Normal file
370
tests/test_transform/test_transform_wrapper.py
Normal file
@ -0,0 +1,370 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmcv.transform.base import BaseTransform
|
||||
from mmcv.transform.builder import TRANSFORMS
|
||||
from mmcv.transform.utils import cache_random_params, cacheable_method
|
||||
from mmcv.transform.wrappers import (ApplyToMultiple, Compose, RandomChoice,
|
||||
Remap)
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class AddToValue(BaseTransform):
|
||||
"""Dummy transform to test transform wrappers."""
|
||||
|
||||
def __init__(self, constant_addend=0, use_random_addend=False) -> None:
|
||||
super().__init__()
|
||||
self.constant_addend = constant_addend
|
||||
self.use_random_addend = use_random_addend
|
||||
|
||||
@cacheable_method
|
||||
def get_random_addend(self):
|
||||
return np.random.rand()
|
||||
|
||||
def transform(self, results):
|
||||
augend = results['value']
|
||||
|
||||
if isinstance(augend, list):
|
||||
warnings.warn('value is a list', UserWarning)
|
||||
if isinstance(augend, dict):
|
||||
warnings.warn('value is a dict', UserWarning)
|
||||
|
||||
def _add_to_value(augend, addend):
|
||||
if isinstance(augend, list):
|
||||
return [_add_to_value(v, addend) for v in augend]
|
||||
if isinstance(augend, dict):
|
||||
return {k: _add_to_value(v, addend) for k, v in augend.items()}
|
||||
return augend + addend
|
||||
|
||||
if self.use_random_addend:
|
||||
addend = self.get_random_addend()
|
||||
else:
|
||||
addend = self.constant_addend
|
||||
|
||||
results['value'] = _add_to_value(results['value'], addend)
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class SumTwoValues(BaseTransform):
|
||||
"""Dummy transform to test transform wrappers."""
|
||||
|
||||
def transform(self, results):
|
||||
if 'num_1' in results and 'num_2' in results:
|
||||
results['sum'] = results['num_1'] + results['num_2']
|
||||
else:
|
||||
results['sum'] = np.nan
|
||||
return results
|
||||
|
||||
|
||||
def test_compose():
|
||||
|
||||
# Case 1: build from cfg
|
||||
pipeline = [dict(type='AddToValue')]
|
||||
pipeline = Compose(pipeline)
|
||||
_ = str(pipeline)
|
||||
|
||||
# Case 2: build from transform list
|
||||
pipeline = [AddToValue()]
|
||||
pipeline = Compose(pipeline)
|
||||
|
||||
# Case 3: invalid build arguments
|
||||
pipeline = [[dict(type='AddToValue')]]
|
||||
with pytest.raises(TypeError):
|
||||
pipeline = Compose(pipeline)
|
||||
|
||||
# Case 4: contain transform with None output
|
||||
class DummyTransform(BaseTransform):
|
||||
|
||||
def transform(self, results):
|
||||
return None
|
||||
|
||||
pipeline = Compose([DummyTransform()])
|
||||
results = pipeline({})
|
||||
assert results is None
|
||||
|
||||
|
||||
def test_cache_random_parameters():
|
||||
|
||||
transform = AddToValue(use_random_addend=True)
|
||||
|
||||
# Case 1: cache random parameters
|
||||
assert hasattr(AddToValue, '_cacheable_methods')
|
||||
assert 'get_random_addend' in AddToValue._cacheable_methods
|
||||
|
||||
with cache_random_params(transform):
|
||||
results_1 = transform(dict(value=0))
|
||||
results_2 = transform(dict(value=0))
|
||||
np.testing.assert_equal(results_1['value'], results_2['value'])
|
||||
|
||||
# Case 2: do not cache random parameters
|
||||
results_1 = transform(dict(value=0))
|
||||
results_2 = transform(dict(value=0))
|
||||
with pytest.raises(AssertionError):
|
||||
np.testing.assert_equal(results_1['value'], results_2['value'])
|
||||
|
||||
# Case 3: invalid use of cacheable methods
|
||||
with pytest.raises(RuntimeError):
|
||||
with cache_random_params(transform):
|
||||
_ = transform.get_random_addend()
|
||||
|
||||
# Case 4: apply on nested transforms
|
||||
transform = Compose([AddToValue(use_random_addend=True)])
|
||||
with cache_random_params(transform):
|
||||
results_1 = transform(dict(value=0))
|
||||
results_2 = transform(dict(value=0))
|
||||
np.testing.assert_equal(results_1['value'], results_2['value'])
|
||||
|
||||
|
||||
def test_remap():
|
||||
|
||||
# Case 1: simple remap
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=1)],
|
||||
input_mapping=dict(value='v_in'),
|
||||
output_mapping=dict(value='v_out'))
|
||||
|
||||
results = dict(value=0, v_in=1)
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0) # should be unchanged
|
||||
np.testing.assert_equal(results['v_in'], 1)
|
||||
np.testing.assert_equal(results['v_out'], 2)
|
||||
|
||||
# Case 2: collecting list
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=2)],
|
||||
input_mapping=dict(value=['v_in_1', 'v_in_2']),
|
||||
output_mapping=dict(value=['v_out_1', 'v_out_2']))
|
||||
results = dict(value=0, v_in_1=1, v_in_2=2)
|
||||
|
||||
with pytest.warns(UserWarning, match='value is a list'):
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0) # should be unchanged
|
||||
np.testing.assert_equal(results['v_in_1'], 1)
|
||||
np.testing.assert_equal(results['v_in_2'], 2)
|
||||
np.testing.assert_equal(results['v_out_1'], 3)
|
||||
np.testing.assert_equal(results['v_out_2'], 4)
|
||||
|
||||
# Case 3: collecting dict
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=2)],
|
||||
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
|
||||
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
|
||||
results = dict(value=0, v_in_1=1, v_in_2=2)
|
||||
|
||||
with pytest.warns(UserWarning, match='value is a dict'):
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0) # should be unchanged
|
||||
np.testing.assert_equal(results['v_in_1'], 1)
|
||||
np.testing.assert_equal(results['v_in_2'], 2)
|
||||
np.testing.assert_equal(results['v_out_1'], 3)
|
||||
np.testing.assert_equal(results['v_out_2'], 4)
|
||||
|
||||
# Case 4: collecting list with inplace mode
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=2)],
|
||||
input_mapping=dict(value=['v_in_1', 'v_in_2']),
|
||||
inplace=True)
|
||||
results = dict(value=0, v_in_1=1, v_in_2=2)
|
||||
|
||||
with pytest.warns(UserWarning, match='value is a list'):
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0)
|
||||
np.testing.assert_equal(results['v_in_1'], 3)
|
||||
np.testing.assert_equal(results['v_in_2'], 4)
|
||||
|
||||
# Case 5: collecting dict with inplace mode
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=2)],
|
||||
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
|
||||
inplace=True)
|
||||
results = dict(value=0, v_in_1=1, v_in_2=2)
|
||||
|
||||
with pytest.warns(UserWarning, match='value is a dict'):
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0)
|
||||
np.testing.assert_equal(results['v_in_1'], 3)
|
||||
np.testing.assert_equal(results['v_in_2'], 4)
|
||||
|
||||
# Case 6: nested collection with inplace mode
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=2)],
|
||||
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
|
||||
inplace=True)
|
||||
results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
|
||||
|
||||
with pytest.warns(UserWarning, match='value is a list'):
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0)
|
||||
np.testing.assert_equal(results['v1'], 3)
|
||||
np.testing.assert_equal(results['v21'], 4)
|
||||
np.testing.assert_equal(results['v22'], 5)
|
||||
np.testing.assert_equal(results['v3'], 6)
|
||||
|
||||
# Case 7: `strict` must be True if `inplace` is set True
|
||||
with pytest.raises(ValueError):
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=2)],
|
||||
input_mapping=dict(value=['v_in_1', 'v_in_2']),
|
||||
inplace=True,
|
||||
strict=False)
|
||||
|
||||
# Case 8: output_map must be None if `inplace` is set True
|
||||
with pytest.raises(ValueError):
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=1)],
|
||||
input_mapping=dict(value='v_in'),
|
||||
output_mapping=dict(value='v_out'),
|
||||
inplace=True)
|
||||
|
||||
# Case 9: non-strict input mapping
|
||||
pipeline = Remap(
|
||||
transforms=[SumTwoValues()],
|
||||
input_mapping=dict(num_1='a', num_2='b'),
|
||||
strict=False)
|
||||
|
||||
results = pipeline(dict(a=1, b=2))
|
||||
np.testing.assert_equal(results['sum'], 3)
|
||||
|
||||
results = pipeline(dict(a=1))
|
||||
assert np.isnan(results['sum'])
|
||||
|
||||
# Test basic functions
|
||||
pipeline = Remap(
|
||||
transforms=[AddToValue(constant_addend=1)],
|
||||
input_mapping=dict(value='v_in'),
|
||||
output_mapping=dict(value='v_out'))
|
||||
|
||||
# __iter__
|
||||
for _ in pipeline:
|
||||
pass
|
||||
|
||||
# __repr__
|
||||
_ = str(pipeline)
|
||||
|
||||
|
||||
def test_apply_to_multiple():
|
||||
|
||||
# Case 1: apply to list in results
|
||||
pipeline = ApplyToMultiple(
|
||||
transforms=[AddToValue(constant_addend=1)],
|
||||
input_mapping=dict(value='values'),
|
||||
inplace=True)
|
||||
results = dict(values=[1, 2])
|
||||
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['values'], [2, 3])
|
||||
|
||||
# Case 2: apply to multiple keys
|
||||
pipeline = ApplyToMultiple(
|
||||
transforms=[AddToValue(constant_addend=1)],
|
||||
input_mapping=dict(value=['v_1', 'v_2']),
|
||||
inplace=True)
|
||||
results = dict(v_1=1, v_2=2)
|
||||
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['v_1'], 2)
|
||||
np.testing.assert_equal(results['v_2'], 3)
|
||||
|
||||
# Case 3: apply to multiple groups of keys
|
||||
pipeline = ApplyToMultiple(
|
||||
transforms=[SumTwoValues()],
|
||||
input_mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
|
||||
output_mapping=dict(sum=['a', 'b']))
|
||||
|
||||
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['a'], 3)
|
||||
np.testing.assert_equal(results['b'], 7)
|
||||
|
||||
# Case 4: inconsistent sequence length
|
||||
with pytest.raises(ValueError):
|
||||
pipeline = ApplyToMultiple(
|
||||
transforms=[SumTwoValues()],
|
||||
input_mapping=dict(num_1='list_1', num_2='list_2'))
|
||||
|
||||
results = dict(list_1=[1, 2], list_2=[1, 2, 3])
|
||||
_ = pipeline(results)
|
||||
|
||||
# Case 5: share random parameter
|
||||
pipeline = ApplyToMultiple(
|
||||
transforms=[AddToValue(use_random_addend=True)],
|
||||
input_mapping=dict(value='values'),
|
||||
inplace=True,
|
||||
share_random_params=True,
|
||||
)
|
||||
|
||||
results = dict(values=[0, 0])
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['values'][0], results['values'][1])
|
||||
|
||||
# Test repr
|
||||
_ = str(pipeline)
|
||||
|
||||
|
||||
def test_randomchoice():
|
||||
|
||||
# Case 1: given probability
|
||||
pipeline = RandomChoice(
|
||||
pipelines=[[AddToValue(constant_addend=1.0)],
|
||||
[AddToValue(constant_addend=2.0)]],
|
||||
pipeline_probs=[1.0, 0.0])
|
||||
|
||||
results = pipeline(dict(value=1))
|
||||
np.testing.assert_equal(results['value'], 2.0)
|
||||
|
||||
# Case 1: default probability
|
||||
pipeline = RandomChoice(pipelines=[[AddToValue(
|
||||
constant_addend=1.0)], [AddToValue(constant_addend=2.0)]])
|
||||
|
||||
_ = pipeline(dict(value=1))
|
||||
|
||||
|
||||
def test_utils():
|
||||
# Test cacheable_method: normal case
|
||||
class DummyTransform(BaseTransform):
|
||||
|
||||
@cacheable_method
|
||||
def func(self):
|
||||
return np.random.rand()
|
||||
|
||||
def transform(self, results):
|
||||
_ = self.func()
|
||||
return results
|
||||
|
||||
transform = DummyTransform()
|
||||
_ = transform({})
|
||||
with cache_random_params(transform):
|
||||
_ = transform({})
|
||||
|
||||
# Test cacheable_method: invalid function type
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class DummyTransform():
|
||||
|
||||
@cacheable_method
|
||||
@staticmethod
|
||||
def func():
|
||||
return np.random.rand()
|
||||
|
||||
# Test cacheable_method: invalid function argument list
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class DummyTransform():
|
||||
|
||||
@cacheable_method
|
||||
def func(cls):
|
||||
return np.random.rand()
|
Loading…
x
Reference in New Issue
Block a user