mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] add testing utils (#475)
* add testing utils * fix ut * add blank line betweeen `Args` and `Returns`
This commit is contained in:
parent
7e423cf23f
commit
fba9a94f52
@ -1,4 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .compare import assert_allclose
|
||||
from .compare import (assert_allclose, assert_attrs_equal,
|
||||
assert_dict_contains_subset, assert_dict_has_keys,
|
||||
assert_is_norm_layer, assert_keys_equal,
|
||||
assert_params_all_zeros, check_python_script)
|
||||
|
||||
__all__ = ['assert_allclose']
|
||||
__all__ = [
|
||||
'assert_allclose', 'assert_dict_contains_subset', 'assert_keys_equal',
|
||||
'assert_attrs_equal', 'assert_dict_has_keys', 'assert_is_norm_layer',
|
||||
'assert_params_all_zeros', 'check_python_script'
|
||||
]
|
||||
|
@ -1,10 +1,17 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Callable, Optional, Union
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from runpy import run_path
|
||||
from shlex import split
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
from torch.nn import GroupNorm, LayerNorm
|
||||
from torch.testing import assert_allclose as _assert_allclose
|
||||
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
|
||||
|
||||
def assert_allclose(
|
||||
@ -47,3 +54,135 @@ def assert_allclose(
|
||||
# when PyTorch < 1.6
|
||||
_assert_allclose(
|
||||
actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
|
||||
|
||||
def check_python_script(cmd):
|
||||
"""Run the python cmd script with `__main__`. The difference between
|
||||
`os.system` is that, this function exectues code in the current process, so
|
||||
that it can be tracked by coverage tools. Currently it supports two forms:
|
||||
|
||||
- ./tests/data/scripts/hello.py zz
|
||||
- python tests/data/scripts/hello.py zz
|
||||
"""
|
||||
args = split(cmd)
|
||||
if args[0] == 'python':
|
||||
args = args[1:]
|
||||
with patch.object(sys, 'argv', args):
|
||||
run_path(args[0], run_name='__main__')
|
||||
|
||||
|
||||
def _any(judge_result):
|
||||
"""Since built-in ``any`` works only when the element of iterable is not
|
||||
iterable, implement the function."""
|
||||
if not isinstance(judge_result, Iterable):
|
||||
return judge_result
|
||||
|
||||
try:
|
||||
for element in judge_result:
|
||||
if _any(element):
|
||||
return True
|
||||
except TypeError:
|
||||
# Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
|
||||
if judge_result:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
|
||||
expected_subset: Dict[Any, Any]) -> bool:
|
||||
"""Check if the dict_obj contains the expected_subset.
|
||||
|
||||
Args:
|
||||
dict_obj (Dict[Any, Any]): Dict object to be checked.
|
||||
expected_subset (Dict[Any, Any]): Subset expected to be contained in
|
||||
dict_obj.
|
||||
|
||||
Returns:
|
||||
bool: Whether the dict_obj contains the expected_subset.
|
||||
"""
|
||||
|
||||
for key, value in expected_subset.items():
|
||||
if key not in dict_obj.keys() or _any(dict_obj[key] != value):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
|
||||
"""Check if attribute of class object is correct.
|
||||
|
||||
Args:
|
||||
obj (object): Class object to be checked.
|
||||
expected_attrs (Dict[str, Any]): Dict of the expected attrs.
|
||||
|
||||
Returns:
|
||||
bool: Whether the attribute of class object is correct.
|
||||
"""
|
||||
for attr, value in expected_attrs.items():
|
||||
if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def assert_dict_has_keys(obj: Dict[str, Any],
|
||||
expected_keys: List[str]) -> bool:
|
||||
"""Check if the obj has all the expected_keys.
|
||||
|
||||
Args:
|
||||
obj (Dict[str, Any]): Object to be checked.
|
||||
expected_keys (List[str]): Keys expected to contained in the keys of
|
||||
the obj.
|
||||
|
||||
Returns:
|
||||
bool: Whether the obj has the expected keys.
|
||||
"""
|
||||
return set(expected_keys).issubset(set(obj.keys()))
|
||||
|
||||
|
||||
def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
|
||||
"""Check if target_keys is equal to result_keys.
|
||||
|
||||
Args:
|
||||
result_keys (List[str]): Result keys to be checked.
|
||||
target_keys (List[str]): Target keys to be checked.
|
||||
|
||||
Returns:
|
||||
bool: Whether target_keys is equal to result_keys.
|
||||
"""
|
||||
return set(result_keys) == set(target_keys)
|
||||
|
||||
|
||||
def assert_is_norm_layer(module) -> bool:
|
||||
"""Check if the module is a norm layer.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to be checked.
|
||||
|
||||
Returns:
|
||||
bool: Whether the module is a norm layer.
|
||||
"""
|
||||
|
||||
norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
|
||||
return isinstance(module, norm_layer_candidates)
|
||||
|
||||
|
||||
def assert_params_all_zeros(module) -> bool:
|
||||
"""Check if the parameters of the module is all zeros.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to be checked.
|
||||
|
||||
Returns:
|
||||
bool: Whether the parameters of the module is all zeros.
|
||||
"""
|
||||
weight_data = module.weight.data
|
||||
is_weight_zero = weight_data.allclose(
|
||||
weight_data.new_zeros(weight_data.size()))
|
||||
|
||||
if hasattr(module, 'bias') and module.bias is not None:
|
||||
bias_data = module.bias.data
|
||||
is_bias_zero = bias_data.allclose(
|
||||
bias_data.new_zeros(bias_data.size()))
|
||||
else:
|
||||
is_bias_zero = True
|
||||
|
||||
return is_weight_zero and is_bias_zero
|
||||
|
25
tests/data/scripts/hello.py
Normal file
25
tests/data/scripts/hello.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Say hello.')
|
||||
parser.add_argument('name', help='To whom.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
print(f'hello {args.name}!')
|
||||
if args.name == 'agent':
|
||||
warnings.warn('I have a secret!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
197
tests/test_testing/test_compare.py
Normal file
197
tests/test_testing/test_compare.py
Normal file
@ -0,0 +1,197 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mmengine.testing as testing
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
torch = None
|
||||
else:
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def test_assert_dict_contains_subset():
|
||||
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)}
|
||||
|
||||
# case 1
|
||||
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)}
|
||||
assert testing.assert_dict_contains_subset(dict_obj, expected_subset)
|
||||
|
||||
# case 2
|
||||
expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)}
|
||||
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
|
||||
|
||||
# case 3
|
||||
expected_subset = {'a': 'test1', 'b': 2, 'c': None}
|
||||
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
|
||||
|
||||
# case 4
|
||||
expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)}
|
||||
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
|
||||
|
||||
# case 5
|
||||
dict_obj = {
|
||||
'a': 'test1',
|
||||
'b': 2,
|
||||
'c': (4, 6),
|
||||
'd': np.array([[5, 3, 5], [1, 2, 3]])
|
||||
}
|
||||
expected_subset = {
|
||||
'a': 'test1',
|
||||
'b': 2,
|
||||
'c': (4, 6),
|
||||
'd': np.array([[5, 3, 5], [6, 2, 3]])
|
||||
}
|
||||
assert not testing.assert_dict_contains_subset(dict_obj, expected_subset)
|
||||
|
||||
# case 6
|
||||
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
|
||||
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
|
||||
assert testing.assert_dict_contains_subset(dict_obj, expected_subset)
|
||||
|
||||
if torch is not None:
|
||||
dict_obj = {
|
||||
'a': 'test1',
|
||||
'b': 2,
|
||||
'c': (4, 6),
|
||||
'd': torch.tensor([5, 3, 5])
|
||||
}
|
||||
|
||||
# case 7
|
||||
expected_subset = {'d': torch.tensor([5, 5, 5])}
|
||||
assert not testing.assert_dict_contains_subset(dict_obj,
|
||||
expected_subset)
|
||||
|
||||
# case 8
|
||||
expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])}
|
||||
assert not testing.assert_dict_contains_subset(dict_obj,
|
||||
expected_subset)
|
||||
|
||||
|
||||
def test_assert_attrs_equal():
|
||||
|
||||
class TestExample:
|
||||
a, b, c = 1, ('wvi', 3), [4.5, 3.14]
|
||||
|
||||
def test_func(self):
|
||||
return self.b
|
||||
|
||||
# case 1
|
||||
assert testing.assert_attrs_equal(TestExample, {
|
||||
'a': 1,
|
||||
'b': ('wvi', 3),
|
||||
'c': [4.5, 3.14]
|
||||
})
|
||||
|
||||
# case 2
|
||||
assert not testing.assert_attrs_equal(TestExample, {
|
||||
'a': 1,
|
||||
'b': ('wvi', 3),
|
||||
'c': [4.5, 3.14, 2]
|
||||
})
|
||||
|
||||
# case 3
|
||||
assert not testing.assert_attrs_equal(TestExample, {
|
||||
'bc': 54,
|
||||
'c': [4.5, 3.14]
|
||||
})
|
||||
|
||||
# case 4
|
||||
assert testing.assert_attrs_equal(TestExample, {
|
||||
'b': ('wvi', 3),
|
||||
'test_func': TestExample.test_func
|
||||
})
|
||||
|
||||
if torch is not None:
|
||||
|
||||
class TestExample:
|
||||
a, b = torch.tensor([1]), torch.tensor([4, 5])
|
||||
|
||||
# case 5
|
||||
assert testing.assert_attrs_equal(TestExample, {
|
||||
'a': torch.tensor([1]),
|
||||
'b': torch.tensor([4, 5])
|
||||
})
|
||||
|
||||
# case 6
|
||||
assert not testing.assert_attrs_equal(TestExample, {
|
||||
'a': torch.tensor([1]),
|
||||
'b': torch.tensor([4, 6])
|
||||
})
|
||||
|
||||
|
||||
assert_dict_has_keys_data_1 = [({
|
||||
'res_layer': 1,
|
||||
'norm_layer': 2,
|
||||
'dense_layer': 3
|
||||
})]
|
||||
assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True),
|
||||
(['res_layer', 'conv_layer'], False)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1)
|
||||
@pytest.mark.parametrize('expected_keys, ret_value',
|
||||
assert_dict_has_keys_data_2)
|
||||
def test_assert_dict_has_keys(obj, expected_keys, ret_value):
|
||||
assert testing.assert_dict_has_keys(obj, expected_keys) == ret_value
|
||||
|
||||
|
||||
assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])]
|
||||
assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True),
|
||||
(['res_layer', 'dense_layer', 'norm_layer'], True),
|
||||
(['res_layer', 'norm_layer'], False),
|
||||
(['res_layer', 'conv_layer', 'norm_layer'], False)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1)
|
||||
@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2)
|
||||
def test_assert_keys_equal(result_keys, target_keys, ret_value):
|
||||
assert testing.assert_keys_equal(result_keys, target_keys) == ret_value
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch is None, reason='requires torch library')
|
||||
def test_assert_is_norm_layer():
|
||||
# case 1
|
||||
assert not testing.assert_is_norm_layer(nn.Conv3d(3, 64, 3))
|
||||
|
||||
# case 2
|
||||
assert testing.assert_is_norm_layer(nn.BatchNorm3d(128))
|
||||
|
||||
# case 3
|
||||
assert testing.assert_is_norm_layer(nn.GroupNorm(8, 64))
|
||||
|
||||
# case 4
|
||||
assert not testing.assert_is_norm_layer(nn.Sigmoid())
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch is None, reason='requires torch library')
|
||||
def test_assert_params_all_zeros():
|
||||
demo_module = nn.Conv2d(3, 64, 3)
|
||||
nn.init.constant_(demo_module.weight, 0)
|
||||
nn.init.constant_(demo_module.bias, 0)
|
||||
assert testing.assert_params_all_zeros(demo_module)
|
||||
|
||||
nn.init.xavier_normal_(demo_module.weight)
|
||||
nn.init.constant_(demo_module.bias, 0)
|
||||
assert not testing.assert_params_all_zeros(demo_module)
|
||||
|
||||
demo_module = nn.Linear(2048, 400, bias=False)
|
||||
nn.init.constant_(demo_module.weight, 0)
|
||||
assert testing.assert_params_all_zeros(demo_module)
|
||||
|
||||
nn.init.normal_(demo_module.weight, mean=0, std=0.01)
|
||||
assert not testing.assert_params_all_zeros(demo_module)
|
||||
|
||||
|
||||
def test_check_python_script(capsys):
|
||||
testing.check_python_script('./tests/data/scripts/hello.py zz')
|
||||
captured = capsys.readouterr().out
|
||||
assert captured == 'hello zz!\n'
|
||||
testing.check_python_script('./tests/data/scripts/hello.py agent')
|
||||
captured = capsys.readouterr().out
|
||||
assert captured == 'hello agent!\n'
|
||||
# Make sure that wrong cmd raises an error
|
||||
with pytest.raises(SystemExit):
|
||||
testing.check_python_script('./tests/data/scripts/hello.py li zz')
|
Loading…
x
Reference in New Issue
Block a user