[Feature] Add base transform interface (#1538)

* Support deepcopy for Config (#1658)

* Support deepcopy for Config

* Iterate the `__dict__` of Config directly.

* Use __new__ to avoid unnecessary initialization.

* Improve according to comments

* [Feature] Add spconv ops from mmdet3d (#1581)

* add ops (spconv) of mmdet3d

* fix typo

* refactor code

* resolve comments in #1452

* fix compile error

* fix bugs

* fix bug

* transform from 'types.h' to 'extension.h'

* fix bug

* transform from 'types.h' to 'extension.h' in parrots

* add extension.h in pybind.cpp

* add unittest

* Recover code

* (1) Remove prettyprint.h
(2) Switch `T` to `scalar_t`
(3) Remove useless lines
(4) Refine example in docstring of sparse_modules.py

* (1) rename from `cu.h` to `cuh`
(2) remove useless files
(3) move cpu files to `pytorch/cpu`

* reorganize files

* Add docstring for sparse_functional.py

* use dispatcher

* remove template

* use dispatch in cuda ops

* resolve Segmentation fault

* remove useless files

* fix lint

* fix lint

* fix lint

* fix unittest in test_build_layers.py

* add tensorview into include_dirs when compiling

* recover all deleted files

* fix lint and comments

* recover setup.py

* replace tv::GPU as tv::TorchGPU & support device guard

* fix lint

Co-authored-by: hdc <hudingchang.vendor@sensetime.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>

* Improve the docstring of imfrombytes and fix a deprecation warning (#1731)

* [Refactor] Refactor the interface for RoIAlignRotated (#1662)

* fix interface for RoIAlignRotated

* Add a unit test for RoIAlignRotated

* Make a unit test for RoIAlignRotated concise

* fix interface for RoIAlignRotated

* Refactor ext_module.nms_rotated

* Lint cpp files

* add transforms

* add invoking time check for cacheable methods

* fix lint

* add unittest

* fix bug in non-strict input mapping

* fix ci

* fix ci

* fix compatibility with python<3.9

* fix typing compatibility

* fix import

* fix typing

* add alternative for nullcontext

* fix import

* fix import

* add docstrings

* add docstrings

* fix callable check

* resolve comments

* fix lint

* enrich unittest cases

* fix lint

* fix unittest

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com>
Co-authored-by: hdc <hudingchang.vendor@sensetime.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Co-authored-by: Hakjin Lee <nijkah@gmail.com>
Yining Li 2022-02-22 00:28:42 +08:00 committed by zhouzaida
parent 8b47579e7c
commit d00b0cec74
9 changed files with 1036 additions and 0 deletions


@@ -47,3 +47,8 @@ ops
------
.. automodule:: mmcv.ops
:members:
transform
---------
.. automodule:: mmcv.transform
:members:


@@ -47,3 +47,8 @@ ops
------
.. automodule:: mmcv.ops
:members:
transform
---------
.. automodule:: mmcv.transform
:members:


@@ -3,6 +3,7 @@
from .arraymisc import *
from .fileio import *
from .image import *
from .transform import *
from .utils import *
from .version import *
from .video import *


@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import TRANSFORMS
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
__all__ = ['TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap']

mmcv/transform/base.py

@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict
class BaseTransform(metaclass=ABCMeta):
def __call__(self, results: Dict) -> Dict:
return self.transform(results)
@abstractmethod
def transform(self, results: Dict) -> Dict:
"""The transform function. All subclass of BaseTransform should
override this method.
This function takes the result dict as the input, and can add new
items to the dict or modify existing items in the dict. And the result
dict will be returned in the end, which allows to concate multiple
transforms into a pipeline.
Args:
results (dict): The result dict.
Returns:
dict: The result dict.
"""


@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..utils.registry import Registry
TRANSFORMS = Registry('transform')
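As a rough usage sketch (the AddConstant class below is illustrative, not from this commit), a transform is typically registered with the TRANSFORMS registry and later built from a config dict:

from mmcv.transform.base import BaseTransform
from mmcv.transform.builder import TRANSFORMS


@TRANSFORMS.register_module()
class AddConstant(BaseTransform):
    """Hypothetical transform that adds a constant to results['value']."""

    def __init__(self, addend=1):
        self.addend = addend

    def transform(self, results: dict) -> dict:
        results['value'] = results['value'] + self.addend
        return results


# `TRANSFORMS.build` instantiates the class named by `type` and passes the
# remaining keys as keyword arguments.
add_one = TRANSFORMS.build(dict(type='AddConstant', addend=1))
assert add_one(dict(value=41))['value'] == 42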

mmcv/transform/utils.py

@@ -0,0 +1,162 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import inspect
import weakref
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from typing import Callable, Union
from .base import BaseTransform
class cacheable_method:
"""Decorator that marks a method of a transform class as a cacheable
method.
This decorator is usually used together with the context-manager
:func:`cache_random_params`. Within this context, a cacheable method will
cache its return value(s) the first time it is invoked, and always
return the cached values on subsequent invocations.
.. note::
Only an instance method can be decorated as a cacheable_method.
"""
def __init__(self, func):
# Check `func` is to be bound as an instance method
if not inspect.isfunction(func):
raise TypeError('Unsupported callable to decorate with '
'@cacheable_method.')
func_args = inspect.getfullargspec(func).args
if len(func_args) == 0 or func_args[0] != 'self':
raise TypeError(
'@cacheable_method should only be used to decorate '
'instance methods (the first argument is `self`).')
functools.update_wrapper(self, func)
self.func = func
self.instance_ref = None
def __set_name__(self, owner, name):
# Maintain a record of decorated methods in the class
if not hasattr(owner, '_cacheable_methods'):
setattr(owner, '_cacheable_methods', [])
owner._cacheable_methods.append(self.__name__)
def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated
# by cacheable_method
instance = self.instance_ref()
name = self.__name__
# Check the flag `self._cache_enabled`, which should be
# set by the contextmanagers like `cache_random_params`
cache_enabled = getattr(instance, '_cache_enabled', False)
if cache_enabled:
# Initialize the cache of the transform instances. The flag
# `cache_enabled` is set by contextmanagers like
# `cache_random_params`.
if not hasattr(instance, '_cache'):
setattr(instance, '_cache', {})
if name not in instance._cache:
instance._cache[name] = self.func(instance, *args, **kwargs)
# Return the cached value
return instance._cache[name]
else:
# Clear cache
if hasattr(instance, '_cache'):
del instance._cache
# Return function output
return self.func(instance, *args, **kwargs)
def __get__(self, obj, cls):
self.instance_ref = weakref.ref(obj)
return self
@contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of cacheable methods in
transforms.
In this mode, cacheable methods will cache their return values on their
first invocation and always return the cached values afterward. This
allows applying random transforms in a deterministic way, for example,
applying the same random transforms to multiple data samples. See
`cacheable_method` for more information.
Args:
transforms (BaseTransform|list[BaseTransform]): The transforms to
enable cache.
"""
# key2method stores the original methods that are replaced by the wrapped
# ones. These methods will be restored when exiting the context.
key2method = dict()
# key2counter counts the invocations of each cacheable_method. This is
# used to check that every cacheable_method is invoked exactly once while
# processing one data sample.
key2counter = defaultdict(int)
def _add_counter(obj, method_name):
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method
@functools.wraps(method)
def wrapped(*args, **kwargs):
key2counter[key] += 1
return method(*args, **kwargs)
return wrapped
def _start_cache(t: BaseTransform):
# Set cache enabled flag
setattr(t, '_cache_enabled', True)
# Store the original method and init the counter
if hasattr(t, '_cacheable_methods'):
setattr(t, 'transform', _add_counter(t, 'transform'))
for name in t._cacheable_methods:
setattr(t, name, _add_counter(t, name))
def _end_cache(t: BaseTransform):
# Remove cache enabled flag
del t._cache_enabled
if hasattr(t, '_cache'):
del t._cache
# Restore the original method
if hasattr(t, '_cacheable_methods'):
key_transform = f'{id(t)}.transform'
for name in t._cacheable_methods:
key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]:
raise RuntimeError(
'The cacheable method should be called once and only '
f'once during processing one data sample. {t} got '
f'unmatched numbers of {key2counter[key]} ({name}) vs. '
f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key])
setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]):
if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'):
func(t)
if isinstance(t, Iterable):
for _t in t:
_apply(_t, func)
try:
_apply(transforms, _start_cache)
yield
finally:
_apply(transforms, _end_cache)
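The intended interplay of cacheable_method and cache_random_params can be sketched as follows (mirroring the dummy transforms in the unit tests below; the RandomAdd class and key names are illustrative):

import numpy as np

from mmcv.transform.base import BaseTransform
from mmcv.transform.utils import cache_random_params, cacheable_method


class RandomAdd(BaseTransform):
    """Hypothetical transform that adds a random addend to results['value']."""

    @cacheable_method
    def get_addend(self):
        # Random parameter generation is isolated in a cacheable method.
        return np.random.rand()

    def transform(self, results: dict) -> dict:
        results['value'] = results['value'] + self.get_addend()
        return results


t = RandomAdd()
with cache_random_params(t):
    # Inside the context both calls reuse the cached addend, so the two
    # outputs are identical; outside the context a new addend would be
    # drawn on every call.
    assert t(dict(value=0))['value'] == t(dict(value=0))['value']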

mmcv/transform/wrappers.py

@@ -0,0 +1,457 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import mmcv
from .base import BaseTransform
from .builder import TRANSFORMS
from .utils import cache_random_params
# Indicator for required but missing keys in results
NotInResults = object()
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation.
try:
from contextlib import nullcontext
except ImportError:
from contextlib import contextmanager
@contextmanager
def nullcontext(resource=None):
try:
yield resource
finally:
pass
class Compose(BaseTransform):
"""Compose multiple transforms sequentially.
Args:
transforms (list[dict | callable]): Sequence of transform object or
config dict to be composed.
Examples:
>>> pipeline = [
>>> dict(type='Compose',
>>> transforms=[
>>> dict(type='LoadImageFromFile'),
>>> dict(type='Normalize')
>>> ]
>>> )
>>> ]
"""
def __init__(self, transforms: List[Union[Dict, Callable[[Dict], Dict]]]):
assert isinstance(transforms, Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = TRANSFORMS.build(transform)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict, but got'
f' {type(transform)}')
def __iter__(self):
"""Allow easy iteration over the transform sequence."""
return iter(self.transforms)
def transform(self, results: Dict) -> Optional[Dict]:
"""Call function to apply transforms sequentially.
Args:
results (dict): A result dict contains the results to transform.
Returns:
dict or None: Transformed results.
"""
for t in self.transforms:
results = t(results)
if results is None:
return None
return results
def __repr__(self):
"""Compute the string representation."""
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += f'\n {t}'
format_string += '\n)'
return format_string
@TRANSFORMS.register_module()
class Remap(BaseTransform):
"""A transform wrapper to remap and reorganize the input/output of the
wrapped transforms (or sub-pipeline).
Args:
transforms (list[dict | callable]): Sequence of transform object or
config dict to be wrapped.
input_mapping (dict): A dict that defines the input key mapping.
The keys correspond to the inner keys (i.e., the keys expected by the
wrapped transforms), and should be strings. The values
correspond to the outer keys (i.e., the keys of the
data/results), and should be of type string, list or dict.
None means no input mapping is applied. Default: None.
output_mapping (dict): A dict that defines the output key mapping.
The keys and values have the same meanings and rules as in the
`input_mapping`. Default: None.
inplace (bool): If True, an inverse of the input_mapping will be used
as the output_mapping. Note that if inplace is set True,
output_mapping should be None and strict should be True.
Default: False.
strict (bool): If True, the outer keys in the input_mapping must exist
in the input data, or an exception will be raised. If False,
the missing keys will be assigned a special value `NotInResults`
during input remapping. Default: True.
Examples:
>>> # Example 1: Remap 'gt_img' to 'img'
>>> pipeline = [
>>> # Use Remap to convert outer (original) field name 'gt_img'
>>> # to inner (used by inner transforms) field name 'img'
>>> dict(type='Remap',
>>> input_mapping=dict(img='gt_img'),
>>> # inplace=True means the output key mapping is the inverse of
>>> # the input key mapping, e.g. inner 'img' will be mapped
>>> # back to outer 'gt_img'
>>> inplace=True,
>>> transforms=[
>>> # In all transforms' implementation just use 'img'
>>> # as a standard field name
>>> dict(type='Crop', crop_size=(384, 384)),
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 2: Collect and structure multiple items
>>> pipeline = [
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
>>> # and 'img_tar', whose values are outer fields 'img1' and
>>> # 'img2' respectively.
>>> dict(
>>> type='Remap',
>>> input_mapping=dict(
>>> imgs=dict(
>>> img_src='img1',
>>> img_tar='img2')),
>>> transforms=...)
>>> ]
"""
def __init__(self,
transforms: List[Union[Dict, Callable[[Dict], Dict]]],
input_mapping: Optional[Dict] = None,
output_mapping: Optional[Dict] = None,
inplace: bool = False,
strict: bool = True):
self.inplace = inplace
self.strict = strict
self.input_mapping = input_mapping
if self.inplace:
if not self.strict:
raise ValueError('Remap: `strict` must be set True if '
'`inplace` is set True.')
if output_mapping is not None:
raise ValueError('Remap: `output_mapping` must be None if '
'`inplace` is set True.')
self.output_mapping = input_mapping
else:
self.output_mapping = output_mapping
self.transforms = Compose(transforms)
def __iter__(self):
"""Allow easy iteration over the transform sequence."""
return iter(self.transforms)
def remap_input(self, data: Dict, input_mapping: Dict) -> Dict[str, Any]:
"""Remap inputs for the wrapped transforms by gathering and renaming
data items according to the input_mapping.
Args:
data (dict): The original input data
input_mapping (dict): The input key mapping. See the document of
`mmcv.transform.wrappers.Remap` for details.
Returns:
dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline.
"""
def _remap(data, m):
if isinstance(m, dict):
# m is a dict {inner_key:outer_key, ...}
return {k_in: _remap(data, k_out) for k_in, k_out in m.items()}
if isinstance(m, (tuple, list)):
# m is a list or tuple [outer_key1, outer_key2, ...]
# This is the case when we collect items from the original
# data to form a list or tuple to feed to the wrapped
# transforms.
return m.__class__(_remap(data, e) for e in m)
# m is an outer_key
if self.strict:
# In strict mode a missing outer key raises a KeyError, as
# promised in the class docstring.
return data[m]
else:
return data.get(m, NotInResults)
collected = _remap(data, input_mapping)
collected = {
k: v
for k, v in collected.items() if v is not NotInResults
}
# Retain unmapped items
inputs = data.copy()
inputs.update(collected)
return inputs
def remap_output(self, data: Dict, output_mapping: Dict) -> Dict[str, Any]:
"""Remap outputs from the wrapped transforms by gathering and renaming
data items according to the output_mapping.
Args:
data (dict): The output of the wrapped pipeline.
output_mapping (dict): The output key mapping. See the document of
`mmcv.transform.wrappers.Remap` for details.
Returns:
dict: The output with remapped keys.
"""
def _remap(data, m):
if isinstance(m, dict):
assert isinstance(data, dict)
results = {}
for k_in, k_out in m.items():
assert k_in in data
results.update(_remap(data[k_in], k_out))
return results
if isinstance(m, (list, tuple)):
assert isinstance(data, (list, tuple))
assert len(data) == len(m)
results = {}
for m_i, d_i in zip(m, data):
results.update(_remap(d_i, m_i))
return results
return {m: data}
# Note that unmapped items are not retained, which is different from
# the behavior in remap_input. This avoids original data items being
# accidentally overwritten by intermediate items that share their names.
return _remap(data, output_mapping)
def transform(self, results: Dict) -> Dict:
inputs = self.remap_input(results, self.input_mapping)
outputs = self.transforms(inputs)
if self.output_mapping:
outputs = self.remap_output(outputs, self.output_mapping)
results.update(outputs)
return results
@TRANSFORMS.register_module()
class ApplyToMultiple(Remap):
"""A transform wrapper to apply the wrapped transforms to multiple data
items. For example, apply Resize to multiple images.
Args:
transforms (list[dict | callable]): Sequence of transform object or
config dict to be wrapped.
input_mapping (dict): A dict that defines the input key mapping.
Note that to apply the transforms to multiple data items, the
outer keys of the target items should be remapped as a list with
the standard inner key (the key required by the wrapped transforms).
See the following examples and the documentation of
`mmcv.transform.wrappers.Remap` for details.
output_mapping (dict): A dict that defines the output key mapping.
The keys and values have the same meanings and rules as in the
`input_mapping`. Default: None.
inplace (bool): If True, an inverse of the input_mapping will be used
as the output_mapping. Note that if inplace is set True,
output_mapping should be None and strict should be True.
Default: False.
strict (bool): If True, the outer keys in the input_mapping must exist
in the input data, or an exception will be raised. If False,
the missing keys will be assigned a special value `NotInResults`
during input remapping. Default: True.
share_random_params (bool): If True, the random transform
(e.g., RandomFlip) will be conducted in a deterministic way and
have the same behavior on all data items. For example, either both
the input image and the ground-truth image are flipped, or neither is.
Default: False.
.. note::
To apply the transforms to each element of a list or tuple (instead
of to separate data items), you can remap the outer key of the target
sequence to the standard inner key. See Example 2 below.
Examples:
>>> # Example 1:
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> # ApplyToMultiple maps multiple outer fields to the standard
>>> # inner field and processes them with the wrapped transforms
>>> # respectively
>>> dict(type='ApplyToMultiple',
>>> # case 1: from multiple outer fields
>>> input_mapping=dict(img=['lq', 'gt']),
>>> inplace=True,
>>> # share_random_params=True means identical random
>>> # parameters are used for every data item
>>> share_random_params=True,
>>> transforms=[
>>> dict(type='Crop', crop_size=(384, 384)),
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 2:
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> # ApplyToMultiple maps multiple outer fields to the standard
>>> # inner field and processes them with the wrapped transforms
>>> # respectively
>>> dict(type='ApplyToMultiple',
>>> # case 2: from one outer field that contains multiple
>>> # data elements (e.g. a list)
>>> input_mapping=dict(img='images'),
>>> inplace=True,
>>> share_random_params=True,
>>> transforms=[
>>> dict(type='Crop', crop_size=(384, 384)),
>>> dict(type='Normalize'),
>>> ])
>>> ]
"""
def __init__(self,
transforms: List[Union[Dict, Callable[[Dict], Dict]]],
input_mapping: Optional[Dict] = None,
output_mapping: Optional[Dict] = None,
inplace: bool = False,
strict: bool = True,
share_random_params: bool = False):
super().__init__(transforms, input_mapping, output_mapping, inplace,
strict)
self.share_random_params = share_random_params
def scatter_sequence(self, data: Dict) -> List[Dict]:
# infer split number from input
seq_len = None
key_rep = None
for key in self.input_mapping:
assert isinstance(data[key], Sequence)
if seq_len is not None:
if len(data[key]) != seq_len:
raise ValueError('Got inconsistent sequence length: '
f'{seq_len} ({key_rep}) vs. '
f'{len(data[key])} ({key})')
else:
seq_len = len(data[key])
key_rep = key
scatters = []
for i in range(seq_len):
scatter = data.copy()
for key in self.input_mapping:
scatter[key] = data[key][i]
scatters.append(scatter)
return scatters
def transform(self, results: Dict):
# Apply input remapping
inputs = self.remap_input(results, self.input_mapping)
# Scatter sequential inputs into a list
inputs = self.scatter_sequence(inputs)
# Control random parameter sharing with a context manager
if self.share_random_params:
# The context manager :func:`cache_random_params` lets the
# cacheable methods of the transforms cache their outputs. Thus
# the random parameters will only be generated once and shared
# by all data items.
ctx = cache_random_params
else:
ctx = nullcontext
with ctx(self.transforms):
outputs = [self.transforms(_input) for _input in inputs]
# Collate output scatters (list of dict to dict of list)
outputs = {
key: [_output[key] for _output in outputs]
for key in outputs[0]
}
# Apply output remapping
if self.output_mapping:
outputs = self.remap_output(outputs, self.output_mapping)
results.update(outputs)
return results
@TRANSFORMS.register_module()
class RandomChoice(BaseTransform):
"""Process data with a randomly chosen pipeline from given candidates.
Args:
pipelines (list[list]): A list of pipeline candidates, each is a
sequence of transforms.
pipeline_probs (list[float], optional): The probabilities associated
with each pipeline. Its length should equal the number of pipelines
and its elements should sum to 1. If not given, a uniform
distribution over all pipelines will be assumed.
Examples:
>>> # config
>>> pipeline = [
>>> dict(type='RandomChoice',
>>> pipelines=[
>>> [dict(type='RandomHorizontalFlip')], # subpipeline 1
>>> [dict(type='RandomRotate')], # subpipeline 2
>>> ]
>>> )
>>> ]
"""
def __init__(self,
pipelines: List[List[Union[Dict, Callable[[Dict], Dict]]]],
pipeline_probs: Optional[List[float]] = None):
if pipeline_probs is not None:
assert mmcv.is_seq_of(pipeline_probs, float)
assert len(pipelines) == len(pipeline_probs), \
'`pipelines` and `pipeline_probs` must have same lengths. ' \
f'Got {len(pipelines)} vs {len(pipeline_probs)}.'
assert sum(pipeline_probs) == 1
self.pipeline_probs = pipeline_probs
self.pipelines = [Compose(transforms) for transforms in pipelines]
def transform(self, results):
pipeline = np.random.choice(self.pipelines, p=self.pipeline_probs)
return pipeline(results)
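To tie the wrappers together, here is a brief sketch (the AddOne class and the key names are illustrative, not from this commit) of applying a wrapped transform to every element of a sequence via ApplyToMultiple, with inplace=True writing the results back to the outer key:

from mmcv.transform.base import BaseTransform
from mmcv.transform.wrappers import ApplyToMultiple


class AddOne(BaseTransform):
    """Hypothetical transform that increments results['value']."""

    def transform(self, results: dict) -> dict:
        results['value'] = results['value'] + 1
        return results


wrapper = ApplyToMultiple(
    transforms=[AddOne()],
    input_mapping=dict(value='imgs'),  # outer 'imgs' -> inner 'value'
    inplace=True)                      # map results back to outer 'imgs'
assert wrapper(dict(imgs=[1, 2, 3]))['imgs'] == [2, 3, 4]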


@@ -0,0 +1,370 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import pytest
from mmcv.transform.base import BaseTransform
from mmcv.transform.builder import TRANSFORMS
from mmcv.transform.utils import cache_random_params, cacheable_method
from mmcv.transform.wrappers import (ApplyToMultiple, Compose, RandomChoice,
Remap)
@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
"""Dummy transform to test transform wrappers."""
def __init__(self, constant_addend=0, use_random_addend=False) -> None:
super().__init__()
self.constant_addend = constant_addend
self.use_random_addend = use_random_addend
@cacheable_method
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
augend = results['value']
if isinstance(augend, list):
warnings.warn('value is a list', UserWarning)
if isinstance(augend, dict):
warnings.warn('value is a dict', UserWarning)
def _add_to_value(augend, addend):
if isinstance(augend, list):
return [_add_to_value(v, addend) for v in augend]
if isinstance(augend, dict):
return {k: _add_to_value(v, addend) for k, v in augend.items()}
return augend + addend
if self.use_random_addend:
addend = self.get_random_addend()
else:
addend = self.constant_addend
results['value'] = _add_to_value(results['value'], addend)
return results
@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
"""Dummy transform to test transform wrappers."""
def transform(self, results):
if 'num_1' in results and 'num_2' in results:
results['sum'] = results['num_1'] + results['num_2']
else:
results['sum'] = np.nan
return results
def test_compose():
# Case 1: build from cfg
pipeline = [dict(type='AddToValue')]
pipeline = Compose(pipeline)
_ = str(pipeline)
# Case 2: build from transform list
pipeline = [AddToValue()]
pipeline = Compose(pipeline)
# Case 3: invalid build arguments
pipeline = [[dict(type='AddToValue')]]
with pytest.raises(TypeError):
pipeline = Compose(pipeline)
# Case 4: contain transform with None output
class DummyTransform(BaseTransform):
def transform(self, results):
return None
pipeline = Compose([DummyTransform()])
results = pipeline({})
assert results is None
def test_cache_random_parameters():
transform = AddToValue(use_random_addend=True)
# Case 1: cache random parameters
assert hasattr(AddToValue, '_cacheable_methods')
assert 'get_random_addend' in AddToValue._cacheable_methods
with cache_random_params(transform):
results_1 = transform(dict(value=0))
results_2 = transform(dict(value=0))
np.testing.assert_equal(results_1['value'], results_2['value'])
# Case 2: do not cache random parameters
results_1 = transform(dict(value=0))
results_2 = transform(dict(value=0))
with pytest.raises(AssertionError):
np.testing.assert_equal(results_1['value'], results_2['value'])
# Case 3: invalid use of cacheable methods
with pytest.raises(RuntimeError):
with cache_random_params(transform):
_ = transform.get_random_addend()
# Case 4: apply on nested transforms
transform = Compose([AddToValue(use_random_addend=True)])
with cache_random_params(transform):
results_1 = transform(dict(value=0))
results_2 = transform(dict(value=0))
np.testing.assert_equal(results_1['value'], results_2['value'])
def test_remap():
# Case 1: simple remap
pipeline = Remap(
transforms=[AddToValue(constant_addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
results = dict(value=0, v_in=1)
results = pipeline(results)
np.testing.assert_equal(results['value'], 0) # should be unchanged
np.testing.assert_equal(results['v_in'], 1)
np.testing.assert_equal(results['v_out'], 2)
# Case 2: collecting list
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
output_mapping=dict(value=['v_out_1', 'v_out_2']))
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
results = pipeline(results)
np.testing.assert_equal(results['value'], 0) # should be unchanged
np.testing.assert_equal(results['v_in_1'], 1)
np.testing.assert_equal(results['v_in_2'], 2)
np.testing.assert_equal(results['v_out_1'], 3)
np.testing.assert_equal(results['v_out_2'], 4)
# Case 3: collecting dict
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
results = pipeline(results)
np.testing.assert_equal(results['value'], 0) # should be unchanged
np.testing.assert_equal(results['v_in_1'], 1)
np.testing.assert_equal(results['v_in_2'], 2)
np.testing.assert_equal(results['v_out_1'], 3)
np.testing.assert_equal(results['v_out_2'], 4)
# Case 4: collecting list with inplace mode
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
results = pipeline(results)
np.testing.assert_equal(results['value'], 0)
np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4)
# Case 5: collecting dict with inplace mode
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
inplace=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
results = pipeline(results)
np.testing.assert_equal(results['value'], 0)
np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4)
# Case 6: nested collection with inplace mode
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
inplace=True)
results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
with pytest.warns(UserWarning, match='value is a list'):
results = pipeline(results)
np.testing.assert_equal(results['value'], 0)
np.testing.assert_equal(results['v1'], 3)
np.testing.assert_equal(results['v21'], 4)
np.testing.assert_equal(results['v22'], 5)
np.testing.assert_equal(results['v3'], 6)
# Case 7: `strict` must be True if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True,
strict=False)
# Case 8: output_map must be None if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(constant_addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'),
inplace=True)
# Case 9: non-strict input mapping
pipeline = Remap(
transforms=[SumTwoValues()],
input_mapping=dict(num_1='a', num_2='b'),
strict=False)
results = pipeline(dict(a=1, b=2))
np.testing.assert_equal(results['sum'], 3)
results = pipeline(dict(a=1))
assert np.isnan(results['sum'])
# Test basic functions
pipeline = Remap(
transforms=[AddToValue(constant_addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
# __iter__
for _ in pipeline:
pass
# __repr__
_ = str(pipeline)
def test_apply_to_multiple():
# Case 1: apply to list in results
pipeline = ApplyToMultiple(
transforms=[AddToValue(constant_addend=1)],
input_mapping=dict(value='values'),
inplace=True)
results = dict(values=[1, 2])
results = pipeline(results)
np.testing.assert_equal(results['values'], [2, 3])
# Case 2: apply to multiple keys
pipeline = ApplyToMultiple(
transforms=[AddToValue(constant_addend=1)],
input_mapping=dict(value=['v_1', 'v_2']),
inplace=True)
results = dict(v_1=1, v_2=2)
results = pipeline(results)
np.testing.assert_equal(results['v_1'], 2)
np.testing.assert_equal(results['v_2'], 3)
# Case 3: apply to multiple groups of keys
pipeline = ApplyToMultiple(
transforms=[SumTwoValues()],
input_mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
output_mapping=dict(sum=['a', 'b']))
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
np.testing.assert_equal(results['a'], 3)
np.testing.assert_equal(results['b'], 7)
# Case 4: inconsistent sequence length
with pytest.raises(ValueError):
pipeline = ApplyToMultiple(
transforms=[SumTwoValues()],
input_mapping=dict(num_1='list_1', num_2='list_2'))
results = dict(list_1=[1, 2], list_2=[1, 2, 3])
_ = pipeline(results)
# Case 5: share random parameter
pipeline = ApplyToMultiple(
transforms=[AddToValue(use_random_addend=True)],
input_mapping=dict(value='values'),
inplace=True,
share_random_params=True,
)
results = dict(values=[0, 0])
results = pipeline(results)
np.testing.assert_equal(results['values'][0], results['values'][1])
# Test repr
_ = str(pipeline)
def test_randomchoice():
# Case 1: given probability
pipeline = RandomChoice(
pipelines=[[AddToValue(constant_addend=1.0)],
[AddToValue(constant_addend=2.0)]],
pipeline_probs=[1.0, 0.0])
results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0)
# Case 2: default probability
pipeline = RandomChoice(pipelines=[[AddToValue(
constant_addend=1.0)], [AddToValue(constant_addend=2.0)]])
_ = pipeline(dict(value=1))
def test_utils():
# Test cacheable_method: normal case
class DummyTransform(BaseTransform):
@cacheable_method
def func(self):
return np.random.rand()
def transform(self, results):
_ = self.func()
return results
transform = DummyTransform()
_ = transform({})
with cache_random_params(transform):
_ = transform({})
# Test cacheable_method: invalid function type
with pytest.raises(TypeError):
class DummyTransform():
@cacheable_method
@staticmethod
def func():
return np.random.rand()
# Test cacheable_method: invalid function argument list
with pytest.raises(TypeError):
class DummyTransform():
@cacheable_method
def func(cls):
return np.random.rand()