mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add MultiProcessTestCase (#136)
* [Enhancement] Provide MultiProcessTestCase to test distributed related modules * remove debugging info * add timeout property
This commit is contained in:
parent
26f24296db
commit
2bf099d33c
4
mmengine/testing/_internal/__init__.py
Normal file
4
mmengine/testing/_internal/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .distributed import MultiProcessTestCase
|
||||
|
||||
__all__ = ['MultiProcessTestCase']
|
355
mmengine/testing/_internal/distributed.py
Normal file
355
mmengine/testing/_internal/distributed.py
Normal file
@ -0,0 +1,355 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright (c) https://github.com/pytorch/pytorch
|
||||
# Modified from https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_distributed.py # noqa: E501
|
||||
|
||||
import faulthandler
|
||||
import logging
|
||||
import multiprocessing
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
import unittest
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import NamedTuple
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import active_children
|
||||
|
||||
# NOTE(review): calling logging.basicConfig at import time configures the
# root logger as a module side effect; acceptable for a test-only helper,
# but confirm this is intended if library code ever imports this module.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the multi-process test machinery below.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestSkip(NamedTuple):
    """Pairs a subprocess exit code with its human-readable skip reason."""

    exit_code: int
    message: str


# Well-known exit codes that a subprocess uses to signal "skip this test",
# keyed by the condition that triggered the skip. The main process maps
# these codes back to ``unittest.SkipTest`` with the stored message.
TEST_SKIPS = {
    'backend_unavailable': TestSkip(
        10, 'Skipped because distributed backend is not available.'),
    'no_cuda': TestSkip(11, 'CUDA is not available.'),
    'multi-gpu-2': TestSkip(12, 'Need at least 2 CUDA device'),
    'generic': TestSkip(
        13, 'Test skipped at subprocess level, look at subprocess log for '
        'skip reason'),
}
|
||||
|
||||
# [How does MultiProcessTestCase work?]
|
||||
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
|
||||
# default `world_size()` returns 2. Let's take `test_rpc_spawn.py` as an
|
||||
# example which inherits from this class. Its `setUp()` method calls into
|
||||
# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
|
||||
# subprocesses. During the spawn, the main process passes the test name to
|
||||
# subprocesses, and the name is acquired from self.id(). The subprocesses
|
||||
# then use the provided test function name to retrieve the function attribute
|
||||
# from the test instance and run it. The main process simply waits for all
|
||||
# subprocesses to join.
|
||||
|
||||
|
||||
class MultiProcessTestCase(TestCase):
    """TestCase that runs each test method in ``world_size`` subprocesses.

    ``__init__`` patches the requested test method via :meth:`join_or_run`:
    in the main process (``rank == MAIN_PROCESS_RANK``) the patched method
    only spawns and joins subprocesses, while each subprocess re-creates the
    test instance in :meth:`_run` and executes the real test body with its
    own rank. Results are communicated back through exit codes and one
    error-reporting pipe per child.
    """

    MAIN_PROCESS_RANK = -1

    # This exit code is used to indicate that the test code had an error and
    # exited abnormally. There are certain tests that might use sys.exit() to
    # simulate failures and in those cases, we can't have an exit code of 0,
    # but we still want to ensure we didn't run into any other errors.
    TEST_ERROR_EXIT_CODE = 10

    # do not early terminate for distributed tests.
    def _should_stop_test_suite(self) -> bool:
        return False

    @property
    def world_size(self) -> int:
        """Number of subprocesses spawned per test."""
        return 2

    @property
    def timeout(self) -> int:
        """Seconds to wait for subprocesses before timing the test out."""
        return 500

    def join_or_run(self, fn):
        """Wrap ``fn`` so the main process joins subprocesses instead of
        running the test body itself; subprocesses run ``fn`` directly."""

        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                # Main process never runs the test body; it only waits for
                # (and validates) the spawned subprocesses.
                self._join_processes(fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    # The main process spawns N subprocesses that run the test.
    # Constructor patches current instance test method to
    # assume the role of the main process and join its subprocesses,
    # or run the underlying test function.
    def __init__(self, method_name: str = 'runTest') -> None:
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))

    def setUp(self) -> None:
        """Reset per-test state; subclasses call ``_spawn_processes`` after
        their own setup to actually start the children."""
        super().setUp()
        self.skip_return_code_checks = []  # type: ignore[var-annotated]
        self.processes = []  # type: ignore[var-annotated]
        self.rank = self.MAIN_PROCESS_RANK
        # Scratch file shared with children (e.g. for file-based init).
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        # pid to pipe consisting of error message from process.
        self.pid_to_pipe = {}  # type: ignore[var-annotated]

    def tearDown(self) -> None:
        super().tearDown()
        for p in self.processes:
            p.terminate()
        # Each Process instance holds a few open file descriptors. The unittest
        # runner creates a new TestCase instance for each test method and keeps
        # it alive until the end of the entire suite. We must thus reset the
        # processes to prevent an effective file descriptor leak.
        self.processes = []

    def _current_test_name(self) -> str:
        # self.id()
        # e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
        return self.id().split('.')[-1]

    def _start_processes(self, proc) -> None:
        """Start ``world_size`` processes of type ``proc`` running the
        current test, wiring one error-reporting pipe per child."""
        self.processes = []
        for rank in range(int(self.world_size)):
            parent_conn, child_conn = torch.multiprocessing.Pipe()
            process = proc(
                target=self.__class__._run,
                name='process ' + str(rank),
                args=(rank, self._current_test_name(), self.file_name,
                      child_conn),
            )
            process.start()
            self.pid_to_pipe[process.pid] = parent_conn
            self.processes.append(process)

    def _spawn_processes(self) -> None:
        """Spawn subprocesses with the 'spawn' start method (required for
        CUDA safety, unlike 'fork')."""
        proc = torch.multiprocessing.get_context('spawn').Process
        self._start_processes(proc)

    class Event(Enum):
        # Request sent from the main process to a (possibly hung) child,
        # asking it to report its current traceback through the pipe.
        GET_TRACEBACK = 1

    @staticmethod
    def _event_listener(parent_pipe, signal_pipe, rank: int):
        """Child-side thread: answer GET_TRACEBACK requests from the main
        process until the test signals completion via ``signal_pipe``."""
        while True:
            ready_pipes = multiprocessing.connection.wait(
                [parent_pipe, signal_pipe])

            if parent_pipe in ready_pipes:

                if parent_pipe.closed:
                    return

                event = parent_pipe.recv()

                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
                    # Return traceback to the parent process.
                    with tempfile.NamedTemporaryFile(mode='r+') as tmp_file:
                        faulthandler.dump_traceback(tmp_file)
                        # Flush buffers and seek to read from the beginning
                        tmp_file.flush()
                        tmp_file.seek(0)
                        parent_pipe.send(tmp_file.read())

            if signal_pipe in ready_pipes:
                return

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str,
             parent_pipe) -> None:
        """Subprocess entry point: rebuild the test case and run the test."""
        self = cls(test_name)

        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def run_test(self, test_name: str, parent_pipe) -> None:
        """Run ``test_name`` in this subprocess, reporting skips and errors
        to the main process via exit codes and ``parent_pipe``."""
        # Start event listener thread.
        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(
            duplex=False)
        event_listener_thread = threading.Thread(
            target=MultiProcessTestCase._event_listener,
            args=(parent_pipe, signal_recv_pipe, self.rank),
            daemon=True,
        )
        event_listener_thread.start()

        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
        # We're retrieving a corresponding test and executing it.
        try:
            getattr(self, test_name)()
        except unittest.SkipTest as se:
            logger.info(f'Process {self.rank} skipping test {test_name} for '
                        f'following reason: {str(se)}')
            sys.exit(TEST_SKIPS['generic'].exit_code)
        except Exception:
            logger.error(
                f'Caught exception: \n{traceback.format_exc()} exiting '
                f'process {self.rank} with exit code: '
                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
            # Send error to parent process.
            parent_pipe.send(traceback.format_exc())
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        finally:
            if signal_send_pipe is not None:
                # Wake the listener thread so it can exit cleanly.
                signal_send_pipe.send(None)

            assert event_listener_thread is not None
            event_listener_thread.join()
            # Close pipe after done with test.
            parent_pipe.close()

    def _get_timedout_process_traceback(self) -> None:
        """Ask every still-running subprocess for its traceback and log it."""
        pipes = []
        for i, process in enumerate(self.processes):
            if process.exitcode is None:
                pipe = self.pid_to_pipe[process.pid]
                try:
                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
                    pipes.append((i, pipe))
                except ConnectionError as e:
                    logger.error(
                        'Encountered error while trying to get traceback '
                        f'for process {i}: {e}')

        # Wait for results.
        for rank, pipe in pipes:
            try:
                # Wait for traceback
                if pipe.poll(5):
                    if pipe.closed:
                        logger.info(
                            f'Pipe closed for process {rank}, cannot retrieve '
                            'traceback')
                        continue

                    # Named so it does not shadow the `traceback` module.
                    remote_traceback = pipe.recv()
                    logger.error(f'Process {rank} timed out with traceback: '
                                 f'\n\n{remote_traceback}')
                else:
                    logger.error('Could not retrieve traceback for timed out '
                                 f'process: {rank}')
            except ConnectionError as e:
                logger.error(
                    'Encountered error while trying to get traceback for '
                    f'process {rank}: {e}')

    def _join_processes(self, fn) -> None:
        """Main-process side: poll subprocesses, handling early errors and
        timeouts, then validate their exit codes."""
        start_time = time.time()
        subprocess_error = False
        try:
            while True:
                # check to see if any subprocess exited with an error early.
                for (i, p) in enumerate(self.processes):
                    # This is the exit code processes exit with if they
                    # encountered an exception.
                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
                        print(
                            f'Process {i} terminated with exit code '
                            f'{p.exitcode}, terminating remaining processes.')
                        _active_children = active_children()
                        for ac in _active_children:
                            ac.terminate()
                        subprocess_error = True
                        break
                if subprocess_error:
                    break
                # All processes have joined cleanly if they all a valid
                # exitcode
                if all([p.exitcode is not None for p in self.processes]):
                    break
                # Check if we should time out the test. If so, we terminate
                # each process.
                elapsed = time.time() - start_time
                if elapsed > self.timeout:
                    self._get_timedout_process_traceback()
                    print(f'Timing out after {self.timeout} seconds and '
                          'killing subprocesses.')
                    for p in self.processes:
                        p.terminate()
                    break
                # Sleep to avoid excessive busy polling.
                time.sleep(0.1)

            elapsed_time = time.time() - start_time

            if fn in self.skip_return_code_checks:
                self._check_no_test_errors(elapsed_time)
            else:
                self._check_return_codes(elapsed_time)
        finally:
            # Close all pipes
            for pid, pipe in self.pid_to_pipe.items():
                pipe.close()

    def _check_no_test_errors(self, elapsed_time) -> None:
        """Checks that we didn't have any errors thrown in the child
        processes."""
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    'Process {} timed out after {} seconds'.format(
                        i, elapsed_time))
            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)

    def _check_return_codes(self, elapsed_time) -> None:
        """Checks that the return codes of all spawned processes match, and
        skips tests if they returned a return code indicating a skipping
        condition."""
        first_process = self.processes[0]
        # first, we check if there are errors in actual processes
        # (via TEST_ERROR_EXIT CODE), and raise an exception for those.
        # the reason we do this is to attempt to raise a more helpful error
        # message than "Process x terminated/timed out"
        # TODO: we should pipe the exception of the failed subprocess here.
        # Currently, the actual exception is displayed as a logging output.
        errored_processes = [
            (i, p) for i, p in enumerate(self.processes)
            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
        ]
        if errored_processes:
            error = ''
            for i, process in errored_processes:
                # Get error from pipe.
                error_message = self.pid_to_pipe[process.pid].recv()
                error += (
                    'Process {} exited with error code {} and exception:\n{}\n'
                    .format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
                            error_message))

            raise RuntimeError(error)
        # If no process exited uncleanly, we check for timeouts, and then
        # ensure each process exited cleanly.
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                # Fix: the continuation literal was missing the ``f`` prefix,
                # so the message printed "{elapsed_time}" verbatim.
                raise RuntimeError(
                    f'Process {i} terminated or timed out after '
                    f'{elapsed_time} seconds')
            # Fix: the continuation literal was missing the ``f`` prefix,
            # so the actual exit codes were never interpolated.
            self.assertEqual(
                p.exitcode,
                first_process.exitcode,
                msg=f'Expect process {i} exit code to match Process 0 exit '
                f'code of {first_process.exitcode}, but got {p.exitcode}')
        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                raise unittest.SkipTest(skip.message)
        self.assertEqual(
            first_process.exitcode,
            0,
            msg=f'Expected zero exit code but got {first_process.exitcode} '
            f'for pid: {first_process.pid}')

    @property
    def is_master(self) -> bool:
        """Whether this process is rank 0."""
        return self.rank == 0
|
Loading…
x
Reference in New Issue
Block a user