# Copyright (c) OpenMMLab. All rights reserved.
import pytest

from mmengine import MMLogger
# yapf: disable
from mmengine.utils.misc import (concat_list, deprecated_api_warning,
                                 deprecated_function, has_method,
                                 import_modules_from_strings, is_list_of,
                                 is_method_overridden, is_seq_of, is_tuple_of,
                                 iter_cast, list_cast, requires_executable,
                                 requires_package, slice_list, to_1tuple,
                                 to_2tuple, to_3tuple, to_4tuple, to_ntuple,
                                 tuple_cast)

# yapf: enable


def test_to_ntuple():
    """A scalar passed to ``to_Ntuple`` is repeated N times in a tuple."""
    single_number = 2
    assert to_1tuple(single_number) == (single_number, )
    assert to_2tuple(single_number) == (single_number, single_number)
    assert to_3tuple(single_number) == (single_number, single_number,
                                        single_number)
    assert to_4tuple(single_number) == (single_number, single_number,
                                        single_number, single_number)
    assert to_ntuple(5)(single_number) == (single_number, single_number,
                                           single_number, single_number,
                                           single_number)
    assert to_ntuple(6)(single_number) == (single_number, single_number,
                                           single_number, single_number,
                                           single_number, single_number)


def test_iter_cast():
    """``list_cast``/``tuple_cast``/``iter_cast`` convert every element and
    reject a non-type target or a non-iterable input with ``TypeError``."""
    assert list_cast([1, 2, 3], int) == [1, 2, 3]
    assert list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0]
    assert list_cast([1, 2, 3], str) == ['1', '2', '3']
    assert tuple_cast((1, 2, 3), str) == ('1', '2', '3')
    assert next(iter_cast([1, 2, 3], str)) == '1'
    with pytest.raises(TypeError):
        iter_cast([1, 2, 3], '')
    with pytest.raises(TypeError):
        iter_cast(1, str)


def test_is_seq_of():
    """``is_seq_of`` and its list/tuple variants check both the container
    type and the type of every element."""
    assert is_seq_of([1.0, 2.0, 3.0], float)
    assert is_seq_of([(1, ), (2, ), (3, )], tuple)
    assert is_seq_of((1.0, 2.0, 3.0), float)
    assert is_list_of([1.0, 2.0, 3.0], float)
    assert not is_seq_of((1.0, 2.0, 3.0), float, seq_type=list)
    assert not is_tuple_of([1.0, 2.0, 3.0], float)
    assert not is_seq_of([1.0, 2, 3], int)
    assert not is_seq_of((1.0, 2, 3), int)


def test_slice_list():
    """``slice_list`` splits a list by the given lengths and validates the
    type and the total of those lengths."""
    in_list = [1, 2, 3, 4, 5, 6]
    assert slice_list(in_list, [1, 2, 3]) == [[1], [2, 3], [4, 5, 6]]
    assert slice_list(in_list, [len(in_list)]) == [in_list]
    with pytest.raises(TypeError):
        slice_list(in_list, 2.0)
    with pytest.raises(ValueError):
        slice_list(in_list, [1, 2])


def test_concat_list():
    """``concat_list`` flattens a list of lists by one level."""
    assert concat_list([[1, 2]]) == [1, 2]
    assert concat_list([[1, 2], [3, 4, 5], [6]]) == [1, 2, 3, 4, 5, 6]


def test_requires_package(capsys):
    """``requires_package`` raises ``RuntimeError`` and prints the names of
    the missing packages; an available package lets the call through."""

    @requires_package('nnn')
    def func_a():
        pass

    @requires_package(['numpy', 'n1', 'n2'])
    def func_b():
        pass

    @requires_package('numpy')
    def func_c():
        return 1

    with pytest.raises(RuntimeError):
        func_a()
    out, _ = capsys.readouterr()
    assert out == ('Prerequisites "nnn" are required in method "func_a" but '
                   'not found, please install them first.\n')

    # Only the missing names ("n1, n2") are expected in the message.
    with pytest.raises(RuntimeError):
        func_b()
    out, _ = capsys.readouterr()
    assert out == (
        'Prerequisites "n1, n2" are required in method "func_b" but not found,'
        ' please install them first.\n')

    assert func_c() == 1


def test_requires_executable(capsys):
    """``requires_executable`` raises ``RuntimeError`` and prints the names
    of the missing executables; an available one lets the call through."""

    # NOTE(review): relies on `ls` and `mv` being discoverable executables,
    # presumably POSIX-only -- confirm for other platforms.
    @requires_executable('nnn')
    def func_a():
        pass

    @requires_executable(['ls', 'n1', 'n2'])
    def func_b():
        pass

    @requires_executable('mv')
    def func_c():
        return 1

    with pytest.raises(RuntimeError):
        func_a()
    out, _ = capsys.readouterr()
    assert out == ('Prerequisites "nnn" are required in method "func_a" but '
                   'not found, please install them first.\n')

    # Only the missing names ("n1, n2") are expected in the message.
    with pytest.raises(RuntimeError):
        func_b()
    out, _ = capsys.readouterr()
    assert out == (
        'Prerequisites "n1, n2" are required in method "func_b" but not found,'
        ' please install them first.\n')

    assert func_c() == 1


def test_import_modules_from_strings():
    """``import_modules_from_strings`` handles lists, single names, empty
    input, bad types and (optionally tolerated) failed imports."""
    # multiple imports
    import os.path as osp_
    import sys as sys_
    osp, sys = import_modules_from_strings(['os.path', 'sys'])
    assert osp == osp_
    assert sys == sys_

    # single imports
    osp = import_modules_from_strings('os.path')
    assert osp == osp_

    # No imports
    assert import_modules_from_strings(None) is None
    assert import_modules_from_strings([]) is None
    assert import_modules_from_strings('') is None

    # Unsupported types
    with pytest.raises(TypeError):
        import_modules_from_strings(1)
    with pytest.raises(TypeError):
        import_modules_from_strings([1])

    # Failed imports
    with pytest.raises(ImportError):
        import_modules_from_strings('_not_implemented_module')
    with pytest.warns(UserWarning):
        imported = import_modules_from_strings(
            '_not_implemented_module', allow_failed_imports=True)
        assert imported is None
    with pytest.warns(UserWarning):
        imported = import_modules_from_strings(['os.path', '_not_implemented'],
                                               allow_failed_imports=True)
        assert imported[0] == osp
        assert imported[1] is None


def test_is_method_overridden():
    """``is_method_overridden`` accepts a subclass or an instance of it, and
    asserts that the base argument is a class."""

    class Base:

        def foo1():
            pass

        def foo2():
            pass

    class Sub(Base):

        def foo1():
            pass

    # test passing sub class directly
    assert is_method_overridden('foo1', Base, Sub)
    assert not is_method_overridden('foo2', Base, Sub)

    # test passing instance of sub class
    sub_instance = Sub()
    assert is_method_overridden('foo1', Base, sub_instance)
    assert not is_method_overridden('foo2', Base, sub_instance)

    # base_class should be a class, not instance
    base_instance = Base()
    with pytest.raises(AssertionError):
        is_method_overridden('foo1', base_instance, sub_instance)


def test_has_method():
    """``has_method`` is true for a method and false for a data attribute."""

    class Foo:

        def __init__(self, name):
            self.name = name

        def print_name(self):
            print(self.name)

    foo = Foo('foo')
    assert not has_method(foo, 'name')
    assert has_method(foo, 'print_name')


def test_deprecated_api_warning():
    """``deprecated_api_warning`` maps a deprecated keyword to its new name
    and asserts when both are supplied together."""

    @deprecated_api_warning(name_dict=dict(old_key='new_key'))
    def dummy_func(new_key=1):
        return new_key

    # replace `old_key` to `new_key`
    assert dummy_func(old_key=2) == 2

    # The expected behavior is to replace the
    # deprecated key `old_key` to `new_key`,
    # but got them in the arguments at the same time
    with pytest.raises(AssertionError):
        dummy_func(old_key=1, new_key=2)


def test_deprecated_function():
    """``deprecated_function`` keeps the wrapped function callable and
    prepends a deprecation note to its docstring."""

    @deprecated_function('0.2.0', '0.3.0', 'toy instruction')
    def deprecated_demo(arg1: int, arg2: int) -> tuple:
        """This is a long summary. This is a long summary. This is a long
        summary. This is a long summary.

        Args:
            arg1 (int): Long description with a line break. Long description
                with a line break.
            arg2 (int): short description.

        Returns:
            Long description without a line break. Long description without
            a line break.
        """

        return arg1, arg2

    # The literal is compared verbatim against ``__doc__`` below, so its line
    # breaks and leading spaces are significant; only the spaces before the
    # closing quotes are removed by ``strip(' ')``.
    expected_docstring = \
    """.. deprecated:: 0.2.0
    Deprecated and will be removed in version 0.3.0.
    Please toy instruction.


    This is a long summary. This is a long summary. This is a long
    summary. This is a long summary.

    Args:
        arg1 (int): Long description with a line break. Long description
            with a line break.
        arg2 (int): short description.

    Returns:
        Long description without a line break. Long description without
        a line break.
    """  # noqa: E122

    MMLogger.get_instance('test_deprecated_function')
    try:
        deprecated_demo(1, 2)
        # NOTE(review): the check on the emitted warning is disabled;
        # `capsys` is not requested by this test -- confirm before enabling.
        # out, _ = capsys.readouterr()
        # assert "'test_misc.deprecated_demo' is deprecated" in out
        assert (1, 2) == deprecated_demo(1, 2)
        assert expected_docstring.strip(' ') == deprecated_demo.__doc__
    finally:
        # Always drop the logger registered above so a failing assertion
        # does not leak it into other tests.
        MMLogger._instance_dict.clear()

    # Test with short summary without args.
    @deprecated_function('0.2.0', '0.3.0', 'toy instruction')
    def deprecated_demo1():
        """Short summary."""

    expected_docstring = \
    """.. deprecated:: 0.2.0
    Deprecated and will be removed in version 0.3.0.
    Please toy instruction.


    Short summary."""  # noqa: E122
    assert expected_docstring.strip(' ') == deprecated_demo1.__doc__