Add common testing function of MM repos (#743)

* add testing function

add unittest for check_dict

add unittest for the function in testing

* polish docstring of testing.py

rename some function

* remove  in is_all_zeros

* modify the comment of check_dict

* modify the testing.py according to feedback

* add test about numpy for function dict_contains_subset

* applying unified style
pull/795/head
congee 2021-01-08 13:18:08 +08:00 committed by GitHub
parent daab369e99
commit 905c9b43b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 313 additions and 2 deletions

View File

@ -9,6 +9,9 @@ from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .testing import (assert_attrs_equal, assert_dict_contains_subset,
assert_dict_has_keys, assert_is_norm_layer,
assert_keys_equal, assert_params_all_zeros)
from .timer import Timer, TimerError, check_time
from .version_utils import digit_version, get_git_hash
@ -23,7 +26,9 @@ except ImportError:
'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_progress', 'track_iter_progress', 'track_parallel_progress',
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings'
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal'
]
else:
from .env import collect_env
@ -49,5 +54,8 @@ else:
'_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
'TORCH_VERSION', 'deprecated_api_warning', 'digit_version',
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena'
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros'
]

View File

@ -0,0 +1,121 @@
# Copyright (c) Open-MMLab.
from collections.abc import Iterable
from typing import Any, Dict, List
def _any(judge_result):
"""Since built-in ``any`` works only when the element of iterable is not
iterable, implement the function."""
if not isinstance(judge_result, Iterable):
return judge_result
try:
for element in judge_result:
if _any(element):
return True
except TypeError:
# Maybe encouter the case: torch.tensor(True) | torch.tensor(False)
if judge_result:
return True
return False
def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
expected_subset: Dict[Any, Any]) -> bool:
"""Check if the dict_obj contains the expected_subset.
Args:
dict_obj (Dict[Any, Any]): Dict object to be checked.
expected_subset (Dict[Any, Any]): Subset expected to be contained in
dict_obj.
Returns:
bool: Whether the dict_obj contains the expected_subset.
"""
for key, value in expected_subset.items():
if key not in dict_obj.keys() or _any(dict_obj[key] != value):
return False
return True
def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
"""Check if attribute of class object is correct.
Args:
obj (object): Class object to be checked.
expected_attrs (Dict[str, Any]): Dict of the expected attrs.
Returns:
bool: Whether the attribute of class object is correct.
"""
for attr, value in expected_attrs.items():
if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
return False
return True
def assert_dict_has_keys(obj: Dict[str, Any],
expected_keys: List[str]) -> bool:
"""Check if the obj has all the expected_keys.
Args:
obj (Dict[str, Any]): Object to be checked.
expected_keys (List[str]): Keys expected to contained in the keys of
the obj.
Returns:
bool: Whether the obj has the expected keys.
"""
return set(expected_keys).issubset(set(obj.keys()))
def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
"""Check if target_keys is equal to result_keys.
Args:
result_keys (List[str]): Result keys to be checked.
target_keys (List[str]): Target keys to be checked.
Returns:
bool: Whether target_keys is equal to result_keys.
"""
return set(result_keys) == set(target_keys)
def assert_is_norm_layer(module) -> bool:
"""Check if the module is a norm layer.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: Whether the module is a norm layer.
"""
from .parrots_wrapper import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
return isinstance(module, norm_layer_candidates)
def assert_params_all_zeros(module) -> bool:
"""Check if the parameters of the module is all zeros.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: Whether the parameters of the module is all zeros.
"""
weight_data = module.weight.data
is_weight_zero = weight_data.allclose(
weight_data.new_zeros(weight_data.size()))
if hasattr(module, 'bias') and module.bias is not None:
bias_data = module.bias.data
is_bias_zero = bias_data.allclose(
bias_data.new_zeros(bias_data.size()))
else:
is_bias_zero = True
return is_weight_zero and is_bias_zero

View File

@ -0,0 +1,182 @@
import numpy as np
import pytest
import mmcv
try:
import torch
except ImportError:
torch = None
else:
import torch.nn as nn
def test_assert_dict_contains_subset():
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)}
# case 1
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)}
assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 2
expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 3
expected_subset = {'a': 'test1', 'b': 2, 'c': None}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 4
expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 5
dict_obj = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': np.array([[5, 3, 5], [1, 2, 3]])
}
expected_subset = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': np.array([[5, 3, 5], [6, 2, 3]])
}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 6
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
if torch is not None:
dict_obj = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': torch.tensor([5, 3, 5])
}
# case 7
expected_subset = {'d': torch.tensor([5, 5, 5])}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 8
expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
def test_assert_attrs_equal():
class TestExample(object):
a, b, c = 1, ('wvi', 3), [4.5, 3.14]
def test_func(self):
return self.b
# case 1
assert mmcv.assert_attrs_equal(TestExample, {
'a': 1,
'b': ('wvi', 3),
'c': [4.5, 3.14]
})
# case 2
assert not mmcv.assert_attrs_equal(TestExample, {
'a': 1,
'b': ('wvi', 3),
'c': [4.5, 3.14, 2]
})
# case 3
assert not mmcv.assert_attrs_equal(TestExample, {
'bc': 54,
'c': [4.5, 3.14]
})
# case 4
assert mmcv.assert_attrs_equal(TestExample, {
'b': ('wvi', 3),
'test_func': TestExample.test_func
})
if torch is not None:
class TestExample(object):
a, b = torch.tensor([1]), torch.tensor([4, 5])
# case 5
assert mmcv.assert_attrs_equal(TestExample, {
'a': torch.tensor([1]),
'b': torch.tensor([4, 5])
})
# case 6
assert not mmcv.assert_attrs_equal(TestExample, {
'a': torch.tensor([1]),
'b': torch.tensor([4, 6])
})
assert_dict_has_keys_data_1 = [({
'res_layer': 1,
'norm_layer': 2,
'dense_layer': 3
})]
assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True),
(['res_layer', 'conv_layer'], False)]
@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1)
@pytest.mark.parametrize('expected_keys, ret_value',
assert_dict_has_keys_data_2)
def test_assert_dict_has_keys(obj, expected_keys, ret_value):
assert mmcv.assert_dict_has_keys(obj, expected_keys) == ret_value
assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])]
assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True),
(['res_layer', 'dense_layer', 'norm_layer'], True),
(['res_layer', 'norm_layer'], False),
(['res_layer', 'conv_layer', 'norm_layer'], False)]
@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1)
@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2)
def test_assert_keys_equal(result_keys, target_keys, ret_value):
assert mmcv.assert_keys_equal(result_keys, target_keys) == ret_value
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_is_norm_layer():
# case 1
assert not mmcv.assert_is_norm_layer(nn.Conv3d(3, 64, 3))
# case 2
assert mmcv.assert_is_norm_layer(nn.BatchNorm3d(128))
# case 3
assert mmcv.assert_is_norm_layer(nn.GroupNorm(8, 64))
# case 4
assert not mmcv.assert_is_norm_layer(nn.Sigmoid())
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_params_all_zeros():
demo_module = nn.Conv2d(3, 64, 3)
nn.init.constant_(demo_module.weight, 0)
nn.init.constant_(demo_module.bias, 0)
assert mmcv.assert_params_all_zeros(demo_module)
nn.init.xavier_normal_(demo_module.weight)
nn.init.constant_(demo_module.bias, 0)
assert not mmcv.assert_params_all_zeros(demo_module)
demo_module = nn.Linear(2048, 400, bias=False)
nn.init.constant_(demo_module.weight, 0)
assert mmcv.assert_params_all_zeros(demo_module)
nn.init.normal_(demo_module.weight, mean=0, std=0.01)
assert not mmcv.assert_params_all_zeros(demo_module)