diff --git a/docs/en/api/utils.rst b/docs/en/api/utils.rst index 92466352..681e15d2 100644 --- a/docs/en/api/utils.rst +++ b/docs/en/api/utils.rst @@ -109,6 +109,7 @@ Miscellaneous to_ntuple check_prerequisites deprecated_api_warning + deprecated_function has_method is_method_overridden import_modules_from_strings diff --git a/docs/zh_cn/api/utils.rst b/docs/zh_cn/api/utils.rst index 92466352..681e15d2 100644 --- a/docs/zh_cn/api/utils.rst +++ b/docs/zh_cn/api/utils.rst @@ -109,6 +109,7 @@ Miscellaneous to_ntuple check_prerequisites deprecated_api_warning + deprecated_function has_method is_method_overridden import_modules_from_strings diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index cdb24d8c..d7fa1fb3 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -2,7 +2,10 @@ import collections.abc import functools import itertools +import logging +import re import subprocess +import textwrap import warnings from collections import abc from importlib import import_module @@ -387,3 +390,72 @@ def has_method(obj: object, method: str) -> bool: bool: True if the object has the method else False. """ return hasattr(obj, method) and callable(getattr(obj, method)) + + +def deprecated_function(since: str, removed_in: str, + instructions: str) -> Callable: + """Marks functions as deprecated. + + Throw a warning when a deprecated function is called, and add a note in the + docstring. Modified from https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py + + Args: + since (str): The version when the function was first deprecated. + removed_in (str): The version when the function will be removed. + instructions (str): The action users should take. + + Returns: + Callable: A new function, which will be deprecated soon. + """ # noqa: E501 + from mmengine import print_log + + def decorator(function): + + @functools.wraps(function) + def wrapper(*args, **kwargs): + print_log( + f"'{function.__module__}.{function.__name__}' " + f'is deprecated in version {since} and will be ' + f'removed in version {removed_in}. Please {instructions}.', + logger='current', + level=logging.WARNING, + ) + return function(*args, **kwargs) + + indent = ' ' + # Add a deprecation note to the docstring. + docstring = function.__doc__ or '' + # Add a note to the docstring. + deprecation_note = textwrap.dedent(f"""\ + .. deprecated:: {since} + Deprecated and will be removed in version {removed_in}. + Please {instructions}. + """) + # Split docstring at first occurrence of newline + pattern = '\n\n' + summary_and_body = re.split(pattern, docstring, 1) + + if len(summary_and_body) > 1: + summary, body = summary_and_body + body = textwrap.indent(textwrap.dedent(body), indent) + summary = '\n'.join( + [textwrap.dedent(string) for string in summary.split('\n')]) + summary = textwrap.indent(summary, prefix=indent) + # Dedent the body. We cannot do this with the presence of the + # summary because the body contains leading whitespaces when the + # summary does not. + new_docstring_parts = [ + deprecation_note, '\n\n', summary, '\n\n', body + ] + else: + summary = summary_and_body[0] + summary = '\n'.join( + [textwrap.dedent(string) for string in summary.split('\n')]) + summary = textwrap.indent(summary, prefix=indent) + new_docstring_parts = [deprecation_note, '\n\n', summary] + + wrapper.__doc__ = ''.join(new_docstring_parts) + + return wrapper + + return decorator diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py new file mode 100644 index 00000000..95d7a006 --- /dev/null +++ b/tests/test_utils/test_misc.py @@ -0,0 +1,285 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine import MMLogger +# yapf: disable +from mmengine.utils.misc import (concat_list, deprecated_api_warning, + deprecated_function, has_method, + import_modules_from_strings, is_list_of, + is_method_overridden, is_seq_of, is_tuple_of, + iter_cast, list_cast, requires_executable, + requires_package, slice_list, to_1tuple, + to_2tuple, to_3tuple, to_4tuple, to_ntuple, + tuple_cast) + +# yapf: enable + + +def test_to_ntuple(): + single_number = 2 + assert to_1tuple(single_number) == (single_number, ) + assert to_2tuple(single_number) == (single_number, single_number) + assert to_3tuple(single_number) == (single_number, single_number, + single_number) + assert to_4tuple(single_number) == (single_number, single_number, + single_number, single_number) + assert to_ntuple(5)(single_number) == (single_number, single_number, + single_number, single_number, + single_number) + assert to_ntuple(6)(single_number) == (single_number, single_number, + single_number, single_number, + single_number, single_number) + + +def test_iter_cast(): + assert list_cast([1, 2, 3], int) == [1, 2, 3] + assert list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0] + assert list_cast([1, 2, 3], str) == ['1', '2', '3'] + assert tuple_cast((1, 2, 3), str) == ('1', '2', '3') + assert next(iter_cast([1, 2, 3], str)) == '1' + with pytest.raises(TypeError): + iter_cast([1, 2, 3], '') + with pytest.raises(TypeError): + iter_cast(1, str) + + +def test_is_seq_of(): + assert is_seq_of([1.0, 2.0, 3.0], float) + assert is_seq_of([(1, ), (2, ), (3, )], tuple) + assert is_seq_of((1.0, 2.0, 3.0), float) + assert is_list_of([1.0, 2.0, 3.0], float) + assert not is_seq_of((1.0, 2.0, 3.0), float, seq_type=list) + assert not is_tuple_of([1.0, 2.0, 3.0], float) + assert not is_seq_of([1.0, 2, 3], int) + assert not is_seq_of((1.0, 2, 3), int) + + +def test_slice_list(): + in_list = [1, 2, 3, 4, 5, 6] + assert slice_list(in_list, [1, 2, 3]) == [[1], [2, 3], [4, 5, 6]] + assert slice_list(in_list, [len(in_list)]) == [in_list] + with pytest.raises(TypeError): + slice_list(in_list, 2.0) + with pytest.raises(ValueError): + slice_list(in_list, [1, 2]) + + +def test_concat_list(): + assert concat_list([[1, 2]]) == [1, 2] + assert concat_list([[1, 2], [3, 4, 5], [6]]) == [1, 2, 3, 4, 5, 6] + + +def test_requires_package(capsys): + + @requires_package('nnn') + def func_a(): + pass + + @requires_package(['numpy', 'n1', 'n2']) + def func_b(): + pass + + @requires_package('numpy') + def func_c(): + return 1 + + with pytest.raises(RuntimeError): + func_a() + out, _ = capsys.readouterr() + assert out == ('Prerequisites "nnn" are required in method "func_a" but ' + 'not found, please install them first.\n') + + with pytest.raises(RuntimeError): + func_b() + out, _ = capsys.readouterr() + assert out == ( + 'Prerequisites "n1, n2" are required in method "func_b" but not found,' + ' please install them first.\n') + + assert func_c() == 1 + + +def test_requires_executable(capsys): + + @requires_executable('nnn') + def func_a(): + pass + + @requires_executable(['ls', 'n1', 'n2']) + def func_b(): + pass + + @requires_executable('mv') + def func_c(): + return 1 + + with pytest.raises(RuntimeError): + func_a() + out, _ = capsys.readouterr() + assert out == ('Prerequisites "nnn" are required in method "func_a" but ' + 'not found, please install them first.\n') + + with pytest.raises(RuntimeError): + func_b() + out, _ = capsys.readouterr() + assert out == ( + 'Prerequisites "n1, n2" are required in method "func_b" but not found,' + ' please install them first.\n') + + assert func_c() == 1 + + +def test_import_modules_from_strings(): + # multiple imports + import os.path as osp_ + import sys as sys_ + osp, sys = import_modules_from_strings(['os.path', 'sys']) + assert osp == osp_ + assert sys == sys_ + + # single imports + osp = import_modules_from_strings('os.path') + assert osp == osp_ + # No imports + assert import_modules_from_strings(None) is None + assert import_modules_from_strings([]) is None + assert import_modules_from_strings('') is None + # Unsupported types + with pytest.raises(TypeError): + import_modules_from_strings(1) + with pytest.raises(TypeError): + import_modules_from_strings([1]) + # Failed imports + with pytest.raises(ImportError): + import_modules_from_strings('_not_implemented_module') + with pytest.warns(UserWarning): + imported = import_modules_from_strings( + '_not_implemented_module', allow_failed_imports=True) + assert imported is None + with pytest.warns(UserWarning): + imported = import_modules_from_strings(['os.path', '_not_implemented'], + allow_failed_imports=True) + assert imported[0] == osp + assert imported[1] is None + + +def test_is_method_overridden(): + + class Base: + + def foo1(): + pass + + def foo2(): + pass + + class Sub(Base): + + def foo1(): + pass + + # test passing sub class directly + assert is_method_overridden('foo1', Base, Sub) + assert not is_method_overridden('foo2', Base, Sub) + + # test passing instance of sub class + sub_instance = Sub() + assert is_method_overridden('foo1', Base, sub_instance) + assert not is_method_overridden('foo2', Base, sub_instance) + + # base_class should be a class, not instance + base_instance = Base() + with pytest.raises(AssertionError): + is_method_overridden('foo1', base_instance, sub_instance) + + +def test_has_method(): + + class Foo: + + def __init__(self, name): + self.name = name + + def print_name(self): + print(self.name) + + foo = Foo('foo') + assert not has_method(foo, 'name') + assert has_method(foo, 'print_name') + + +def test_deprecated_api_warning(): + + @deprecated_api_warning(name_dict=dict(old_key='new_key')) + def dummy_func(new_key=1): + return new_key + + # replace `old_key` to `new_key` + assert dummy_func(old_key=2) == 2 + + # The expected behavior is to replace the + # deprecated key `old_key` to `new_key`, + # but got them in the arguments at the same time + with pytest.raises(AssertionError): + dummy_func(old_key=1, new_key=2) + + +def test_deprecated_function(): + + @deprecated_function('0.2.0', '0.3.0', 'toy instruction') + def deprecated_demo(arg1: int, arg2: int) -> tuple: + """This is a long summary. This is a long summary. This is a long + summary. This is a long summary. + + Args: + arg1 (int): Long description with a line break. Long description + with a line break. + arg2 (int): short description. + + Returns: + Long description without a line break. Long description without + a line break. + """ + + return arg1, arg2 + + MMLogger.get_instance('test_deprecated_function') + deprecated_demo(1, 2) + # out, _ = capsys.readouterr() + # assert "'test_misc.deprecated_demo' is deprecated" in out + assert (1, 2) == deprecated_demo(1, 2) + + expected_docstring = \ + """.. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. + Please toy instruction. + + + This is a long summary. This is a long summary. This is a long + summary. This is a long summary. + + Args: + arg1 (int): Long description with a line break. Long description + with a line break. + arg2 (int): short description. + + Returns: + Long description without a line break. Long description without + a line break. + """ # noqa: E122 + assert expected_docstring.strip(' ') == deprecated_demo.__doc__ + MMLogger._instance_dict.clear() + + # Test with short summary without args. + @deprecated_function('0.2.0', '0.3.0', 'toy instruction') + def deprecated_demo1(): + """Short summary.""" + + expected_docstring = \ + """.. deprecated:: 0.2.0 + Deprecated and will be removed in version 0.3.0. + Please toy instruction. + + + Short summary.""" # noqa: E122 + assert expected_docstring.strip(' ') == deprecated_demo1.__doc__