mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Add test util for checking stand-alone python scripts (#1007)
* Add test util for checking stand-alone scripts Signed-off-by: lizz <lizz@sensetime.com> * Restrict to python scripts Signed-off-by: lizz <lizz@sensetime.com> * fix Signed-off-by: lizz <lizz@sensetime.com> * tiny Signed-off-by: lizz <lizz@sensetime.com> * Allow no capture Signed-off-by: lizz <lizz@sensetime.com> * Simplify interface Signed-off-by: lizz <lizz@sensetime.com> * Technical notes Signed-off-by: lizz <lizz@sensetime.com> * tiny Signed-off-by: lizz <lizz@sensetime.com> * Update hello.py * Update test_testing.py * Update test_testing.py
This commit is contained in:
parent
934b549e23
commit
f61295d944
@ -11,7 +11,8 @@ from .progressbar import (ProgressBar, track_iter_progress,
|
|||||||
track_parallel_progress, track_progress)
|
track_parallel_progress, track_progress)
|
||||||
from .testing import (assert_attrs_equal, assert_dict_contains_subset,
|
from .testing import (assert_attrs_equal, assert_dict_contains_subset,
|
||||||
assert_dict_has_keys, assert_is_norm_layer,
|
assert_dict_has_keys, assert_is_norm_layer,
|
||||||
assert_keys_equal, assert_params_all_zeros)
|
assert_keys_equal, assert_params_all_zeros,
|
||||||
|
check_python_script)
|
||||||
from .timer import Timer, TimerError, check_time
|
from .timer import Timer, TimerError, check_time
|
||||||
from .version_utils import digit_version, get_git_hash
|
from .version_utils import digit_version, get_git_hash
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ except ImportError:
|
|||||||
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
|
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
|
||||||
'digit_version', 'get_git_hash', 'import_modules_from_strings',
|
'digit_version', 'get_git_hash', 'import_modules_from_strings',
|
||||||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||||
'assert_dict_has_keys', 'assert_keys_equal'
|
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script'
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
from .env import collect_env
|
from .env import collect_env
|
||||||
@ -57,5 +58,5 @@ else:
|
|||||||
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
|
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
|
||||||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||||
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
||||||
'assert_params_all_zeros'
|
'assert_params_all_zeros', 'check_python_script'
|
||||||
]
|
]
|
||||||
|
@ -1,6 +1,25 @@
|
|||||||
# Copyright (c) Open-MMLab.
|
# Copyright (c) Open-MMLab.
|
||||||
|
import sys
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from runpy import run_path
|
||||||
|
from shlex import split
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
|
def check_python_script(cmd):
|
||||||
|
"""Run the python cmd script with `__main__`. The difference between
|
||||||
|
`os.system` is that, this function exectues code in the current process, so
|
||||||
|
that it can be tracked by coverage tools. Currently it supports two forms:
|
||||||
|
|
||||||
|
- ./tests/data/scripts/hello.py zz
|
||||||
|
- python tests/data/scripts/hello.py zz
|
||||||
|
"""
|
||||||
|
args = split(cmd)
|
||||||
|
if args[0] == 'python':
|
||||||
|
args = args[1:]
|
||||||
|
with patch.object(sys, 'argv', args):
|
||||||
|
run_path(args[0], run_name='__main__')
|
||||||
|
|
||||||
|
|
||||||
def _any(judge_result):
|
def _any(judge_result):
|
||||||
|
24
tests/data/scripts/hello.py
Executable file
24
tests/data/scripts/hello.py
Executable file
@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='Say hello.')
|
||||||
|
parser.add_argument('name', help='To whom.')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
print(f'hello {args.name}!')
|
||||||
|
if args.name == 'agent':
|
||||||
|
warnings.warn('I have a secret!')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -180,3 +180,15 @@ def test_assert_params_all_zeros():
|
|||||||
|
|
||||||
nn.init.normal_(demo_module.weight, mean=0, std=0.01)
|
nn.init.normal_(demo_module.weight, mean=0, std=0.01)
|
||||||
assert not mmcv.assert_params_all_zeros(demo_module)
|
assert not mmcv.assert_params_all_zeros(demo_module)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_python_script(capsys):
|
||||||
|
mmcv.utils.check_python_script('./tests/data/scripts/hello.py zz')
|
||||||
|
captured = capsys.readouterr().out
|
||||||
|
assert captured == 'hello zz!\n'
|
||||||
|
mmcv.utils.check_python_script('./tests/data/scripts/hello.py agent')
|
||||||
|
captured = capsys.readouterr().out
|
||||||
|
assert captured == 'hello agent!\n'
|
||||||
|
# Make sure that wrong cmd raises an error
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
mmcv.utils.check_python_script('./tests/data/scripts/hello.py li zz')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user