diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 4c10f888d..ffbbe4b29 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -9,6 +9,9 @@ from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, scandir, symlink) from .progressbar import (ProgressBar, track_iter_progress, track_parallel_progress, track_progress) +from .testing import (assert_attrs_equal, assert_dict_contains_subset, + assert_dict_has_keys, assert_is_norm_layer, + assert_keys_equal, assert_params_all_zeros) from .timer import Timer, TimerError, check_time from .version_utils import digit_version, get_git_hash @@ -23,7 +26,9 @@ except ImportError: 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', 'track_progress', 'track_iter_progress', 'track_parallel_progress', 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', - 'digit_version', 'get_git_hash', 'import_modules_from_strings' + 'digit_version', 'get_git_hash', 'import_modules_from_strings', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal' ] else: from .env import collect_env @@ -49,5 +54,8 @@ else: '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', 'deprecated_api_warning', 'digit_version', - 'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena' + 'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena', + 'assert_dict_contains_subset', 'assert_attrs_equal', + 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', + 'assert_params_all_zeros' ] diff --git a/mmcv/utils/testing.py b/mmcv/utils/testing.py new file mode 100644 index 000000000..063e15987 --- /dev/null +++ b/mmcv/utils/testing.py @@ -0,0 +1,121 @@ +# Copyright (c) Open-MMLab. +from collections.abc import Iterable +from typing import Any, Dict, List + + +def _any(judge_result): + """Since built-in ``any`` works only when the element of iterable is not + iterable, implement the function.""" + if not isinstance(judge_result, Iterable): + return judge_result + + try: + for element in judge_result: + if _any(element): + return True + except TypeError: + # Maybe encouter the case: torch.tensor(True) | torch.tensor(False) + if judge_result: + return True + return False + + +def assert_dict_contains_subset(dict_obj: Dict[Any, Any], + expected_subset: Dict[Any, Any]) -> bool: + """Check if the dict_obj contains the expected_subset. + + Args: + dict_obj (Dict[Any, Any]): Dict object to be checked. + expected_subset (Dict[Any, Any]): Subset expected to be contained in + dict_obj. + + Returns: + bool: Whether the dict_obj contains the expected_subset. + """ + + for key, value in expected_subset.items(): + if key not in dict_obj.keys() or _any(dict_obj[key] != value): + return False + return True + + +def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: + """Check if attribute of class object is correct. + + Args: + obj (object): Class object to be checked. + expected_attrs (Dict[str, Any]): Dict of the expected attrs. + + Returns: + bool: Whether the attribute of class object is correct. + """ + for attr, value in expected_attrs.items(): + if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): + return False + return True + + +def assert_dict_has_keys(obj: Dict[str, Any], + expected_keys: List[str]) -> bool: + """Check if the obj has all the expected_keys. + + Args: + obj (Dict[str, Any]): Object to be checked. + expected_keys (List[str]): Keys expected to contained in the keys of + the obj. + + Returns: + bool: Whether the obj has the expected keys. + """ + return set(expected_keys).issubset(set(obj.keys())) + + +def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: + """Check if target_keys is equal to result_keys. + + Args: + result_keys (List[str]): Result keys to be checked. + target_keys (List[str]): Target keys to be checked. + + Returns: + bool: Whether target_keys is equal to result_keys. + """ + return set(result_keys) == set(target_keys) + + +def assert_is_norm_layer(module) -> bool: + """Check if the module is a norm layer. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the module is a norm layer. + """ + from .parrots_wrapper import _BatchNorm, _InstanceNorm + from torch.nn import GroupNorm, LayerNorm + norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) + return isinstance(module, norm_layer_candidates) + + +def assert_params_all_zeros(module) -> bool: + """Check if the parameters of the module is all zeros. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: Whether the parameters of the module is all zeros. + """ + weight_data = module.weight.data + is_weight_zero = weight_data.allclose( + weight_data.new_zeros(weight_data.size())) + + if hasattr(module, 'bias') and module.bias is not None: + bias_data = module.bias.data + is_bias_zero = bias_data.allclose( + bias_data.new_zeros(bias_data.size())) + else: + is_bias_zero = True + + return is_weight_zero and is_bias_zero diff --git a/tests/test_utils/test_testing.py b/tests/test_utils/test_testing.py new file mode 100644 index 000000000..dce71a303 --- /dev/null +++ b/tests/test_utils/test_testing.py @@ -0,0 +1,182 @@ +import numpy as np +import pytest + +import mmcv + +try: + import torch +except ImportError: + torch = None +else: + import torch.nn as nn + + +def test_assert_dict_contains_subset(): + dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)} + + # case 1 + expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)} + assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 2 + expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)} + assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 3 + expected_subset = {'a': 'test1', 'b': 2, 'c': None} + assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 4 + expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)} + assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 5 + dict_obj = { + 'a': 'test1', + 'b': 2, + 'c': (4, 6), + 'd': np.array([[5, 3, 5], [1, 2, 3]]) + } + expected_subset = { + 'a': 'test1', + 'b': 2, + 'c': (4, 6), + 'd': np.array([[5, 3, 5], [6, 2, 3]]) + } + assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 6 + dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} + expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} + assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + if torch is not None: + dict_obj = { + 'a': 'test1', + 'b': 2, + 'c': (4, 6), + 'd': torch.tensor([5, 3, 5]) + } + + # case 7 + expected_subset = {'d': torch.tensor([5, 5, 5])} + assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + # case 8 + expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])} + assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) + + +def test_assert_attrs_equal(): + + class TestExample(object): + a, b, c = 1, ('wvi', 3), [4.5, 3.14] + + def test_func(self): + return self.b + + # case 1 + assert mmcv.assert_attrs_equal(TestExample, { + 'a': 1, + 'b': ('wvi', 3), + 'c': [4.5, 3.14] + }) + + # case 2 + assert not mmcv.assert_attrs_equal(TestExample, { + 'a': 1, + 'b': ('wvi', 3), + 'c': [4.5, 3.14, 2] + }) + + # case 3 + assert not mmcv.assert_attrs_equal(TestExample, { + 'bc': 54, + 'c': [4.5, 3.14] + }) + + # case 4 + assert mmcv.assert_attrs_equal(TestExample, { + 'b': ('wvi', 3), + 'test_func': TestExample.test_func + }) + + if torch is not None: + + class TestExample(object): + a, b = torch.tensor([1]), torch.tensor([4, 5]) + + # case 5 + assert mmcv.assert_attrs_equal(TestExample, { + 'a': torch.tensor([1]), + 'b': torch.tensor([4, 5]) + }) + + # case 6 + assert not mmcv.assert_attrs_equal(TestExample, { + 'a': torch.tensor([1]), + 'b': torch.tensor([4, 6]) + }) + + +assert_dict_has_keys_data_1 = [({ + 'res_layer': 1, + 'norm_layer': 2, + 'dense_layer': 3 +})] +assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True), + (['res_layer', 'conv_layer'], False)] + + +@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1) +@pytest.mark.parametrize('expected_keys, ret_value', + assert_dict_has_keys_data_2) +def test_assert_dict_has_keys(obj, expected_keys, ret_value): + assert mmcv.assert_dict_has_keys(obj, expected_keys) == ret_value + + +assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])] +assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True), + (['res_layer', 'dense_layer', 'norm_layer'], True), + (['res_layer', 'norm_layer'], False), + (['res_layer', 'conv_layer', 'norm_layer'], False)] + + +@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1) +@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2) +def test_assert_keys_equal(result_keys, target_keys, ret_value): + assert mmcv.assert_keys_equal(result_keys, target_keys) == ret_value + + +@pytest.mark.skipif(torch is None, reason='requires torch library') +def test_assert_is_norm_layer(): + # case 1 + assert not mmcv.assert_is_norm_layer(nn.Conv3d(3, 64, 3)) + + # case 2 + assert mmcv.assert_is_norm_layer(nn.BatchNorm3d(128)) + + # case 3 + assert mmcv.assert_is_norm_layer(nn.GroupNorm(8, 64)) + + # case 4 + assert not mmcv.assert_is_norm_layer(nn.Sigmoid()) + + +@pytest.mark.skipif(torch is None, reason='requires torch library') +def test_assert_params_all_zeros(): + demo_module = nn.Conv2d(3, 64, 3) + nn.init.constant_(demo_module.weight, 0) + nn.init.constant_(demo_module.bias, 0) + assert mmcv.assert_params_all_zeros(demo_module) + + nn.init.xavier_normal_(demo_module.weight) + nn.init.constant_(demo_module.bias, 0) + assert not mmcv.assert_params_all_zeros(demo_module) + + demo_module = nn.Linear(2048, 400, bias=False) + nn.init.constant_(demo_module.weight, 0) + assert mmcv.assert_params_all_zeros(demo_module) + + nn.init.normal_(demo_module.weight, mean=0, std=0.01) + assert not mmcv.assert_params_all_zeros(demo_module)