diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index ba89c4ff..3de90999 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -14,6 +14,7 @@ from .path import (check_file_exist, fopen, is_abs, is_filepath, mkdir_or_exist, scandir, symlink) from .progressbar import (ProgressBar, track_iter_progress, track_parallel_progress, track_progress) +from .progressbar_rich import track_progress_rich from .timer import Timer, TimerError, check_time from .version_utils import digit_version, get_git_hash @@ -29,5 +30,5 @@ __all__ = [ 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress', 'track_parallel_progress', 'track_progress', 'deprecated_function', - 'apply_to', 'get_object_from_string' + 'apply_to', 'track_progress_rich', 'get_object_from_string' ] diff --git a/mmengine/utils/progressbar.py b/mmengine/utils/progressbar.py index 0062f670..34cdda1b 100644 --- a/mmengine/utils/progressbar.py +++ b/mmengine/utils/progressbar.py @@ -3,14 +3,35 @@ import sys from collections.abc import Iterable from multiprocessing import Pool from shutil import get_terminal_size +from typing import Callable, Union from .timer import Timer class ProgressBar: - """A progress bar which can print the progress.""" + """A progress bar which can print the progress. - def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): + Args: + task_num (int): Number of total steps. Defaults to 0. + bar_width (int): Width of the progress bar. Defaults to 50. + start (bool): Whether to start the progress bar in the constructor. + Defaults to True. + file (file-like): Output stream for the bar. Defaults to sys.stdout.
+ + Examples: + >>> import mmengine + >>> import time + >>> bar = mmengine.ProgressBar(10) + >>> for i in range(10): + ... bar.update() + ... time.sleep(1) + """ + + def __init__(self, + task_num: int = 0, + bar_width: int = 50, + start: bool = True, + file=sys.stdout): self.task_num = task_num self.bar_width = bar_width self.completed = 0 @@ -32,7 +53,12 @@ class ProgressBar: self.file.flush() self.timer = Timer() - def update(self, num_tasks=1): + def update(self, num_tasks: int = 1): + """Update the progress bar. + + Args: + num_tasks (int): Number of newly completed tasks. Defaults to 1. + """ assert num_tasks > 0 self.completed += num_tasks elapsed = self.timer.since_start() @@ -61,7 +87,11 @@ class ProgressBar: self.file.flush() -def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): +def track_progress(func: Callable, + tasks: Union[list, Iterable], + bar_width: int = 50, + file=sys.stdout, + **kwargs): """Track the progress of tasks execution with a progress bar. Tasks are done with a simple for-loop. @@ -81,7 +111,7 @@ def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): assert isinstance(tasks[1], int) task_num = tasks[1] tasks = tasks[0] - elif isinstance(tasks, Iterable): + elif isinstance(tasks, list): task_num = len(tasks) else: raise TypeError( @@ -106,15 +136,15 @@ def init_pool(process_num, initializer=None, initargs=None): return Pool(process_num, initializer, initargs) -def track_parallel_progress(func, - tasks, - nproc, - initializer=None, - initargs=None, - bar_width=50, - chunksize=1, - skip_first=False, - keep_order=True, +def track_parallel_progress(func: Callable, + tasks: Union[list, Iterable], + nproc: int, + initializer: Callable = None, + initargs: tuple = None, + bar_width: int = 50, + chunksize: int = 1, + skip_first: bool = False, + keep_order: bool = True, file=sys.stdout): """Track the progress of parallel task execution with a progress bar.
@@ -147,7 +177,7 @@ def track_parallel_progress(func, assert isinstance(tasks[1], int) task_num = tasks[1] tasks = tasks[0] - elif isinstance(tasks, Iterable): + elif isinstance(tasks, list): task_num = len(tasks) else: raise TypeError( @@ -176,7 +206,9 @@ def track_parallel_progress(func, return results -def track_iter_progress(tasks, bar_width=50, file=sys.stdout): +def track_iter_progress(tasks: Union[list, Iterable], + bar_width: int = 50, + file=sys.stdout): """Track the progress of tasks iteration or enumeration with a progress bar. @@ -196,7 +228,7 @@ def track_iter_progress(tasks, bar_width=50, file=sys.stdout): assert isinstance(tasks[1], int) task_num = tasks[1] tasks = tasks[0] - elif isinstance(tasks, Iterable): + elif isinstance(tasks, list): task_num = len(tasks) else: raise TypeError( diff --git a/mmengine/utils/progressbar_rich.py b/mmengine/utils/progressbar_rich.py new file mode 100644 index 00000000..c126866b --- /dev/null +++ b/mmengine/utils/progressbar_rich.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing import Pool +from typing import Callable, Iterable, Sized + +from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, + TaskProgressColumn, TextColumn, TimeRemainingColumn) +from rich.text import Text + + +class _Worker: + """Function wrapper for ``track_progress_rich``""" + + def __init__(self, func) -> None: + self.func = func + + def __call__(self, inputs): + inputs, idx = inputs + if not isinstance(inputs, (tuple, list)): + inputs = (inputs, ) + + return self.func(*inputs), idx + + +class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): + """Skip calculating remaining time for the first few times. + + Args: + skip_times (int): The number of times to skip. Defaults to 0. 
+ """ + + def __init__(self, *args, skip_times=0, **kwargs): + super().__init__(*args, **kwargs) + self.skip_times = skip_times + + def render(self, task: Task) -> Text: + """Show time remaining.""" + if task.completed <= self.skip_times: + return Text('-:--:--', style='progress.remaining') + return super().render(task) + + +def _tasks_with_index(tasks): + """Add index to tasks.""" + for idx, task in enumerate(tasks): + yield task, idx + + +def track_progress_rich(func: Callable, + tasks: Iterable = tuple(), + task_num: int = None, + nproc: int = 1, + chunksize: int = 1, + description: str = 'Processing', + color: str = 'blue') -> list: + """Track the progress of parallel task execution with a progress bar. The + built-in :mod:`multiprocessing` module is used for process pools and tasks + are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. + + Args: + func (callable): The function to be applied to each task. + tasks (Iterable or Sized): A tuple of tasks. There are several cases + for different format tasks: + - When ``func`` accepts no arguments: tasks should be an empty + tuple, and ``task_num`` must be specified. + - When ``func`` accepts only one argument: tasks should be a tuple + containing the argument. + - When ``func`` accepts multiple arguments: tasks should be a + tuple, with each element representing a set of arguments. + If an element is a ``dict``, it will be parsed as a set of + keyword-only arguments. + Defaults to an empty tuple. + task_num (int, optional): If ``tasks`` is an iterator which does not + have length, the number of tasks can be provided by ``task_num``. + Defaults to None. + nproc (int): Process (worker) number. If ``nproc`` is 1, + use single process. Defaults to 1. + chunksize (int): Refer to :class:`multiprocessing.Pool` for details. + Defaults to 1. + description (str): The description of the progress bar. + Defaults to "Processing". + color (str): The color of the progress bar. Defaults to "blue".
+ + Examples: + >>> import time + + >>> def func(x): + ... time.sleep(1) + ... return x**2 + >>> track_progress_rich(func, range(10), nproc=2) + + Returns: + list: The task results. + """ + if not callable(func): + raise TypeError('func must be a callable object') + if not isinstance(tasks, Iterable): + raise TypeError( + f'tasks must be an iterable object, but got {type(tasks)}') + if isinstance(tasks, Sized): + if len(tasks) == 0: + if task_num is None: + raise ValueError('If tasks is an empty iterable, ' + 'task_num must be set') + else: + tasks = tuple(tuple() for _ in range(task_num)) + else: + if task_num is not None and task_num != len(tasks): + raise ValueError('task_num does not match the length of tasks') + task_num = len(tasks) + + if nproc <= 0: + raise ValueError('nproc must be a positive number') + + skip_times = nproc * chunksize if nproc > 1 else 0 + prog_bar = Progress( + TextColumn('{task.description}'), + BarColumn(), + _SkipFirstTimeRemainingColumn(skip_times=skip_times), + MofNCompleteColumn(), + TaskProgressColumn(show_speed=True), + ) + + worker = _Worker(func) + task_id = prog_bar.add_task( + total=task_num, color=color, description=description) + tasks = _tasks_with_index(tasks) + + # Use single process when nproc is 1, else use multiprocess. 
+ with prog_bar: + if nproc == 1: + results = [] + for task in tasks: + results.append(worker(task)[0]) + prog_bar.update(task_id, advance=1, refresh=True) + else: + with Pool(nproc) as pool: + results = [] + unordered_results = [] + gen = pool.imap_unordered(worker, tasks, chunksize) + try: + for result in gen: + result, idx = result + unordered_results.append((result, idx)) + results.append(None) + prog_bar.update(task_id, advance=1, refresh=True) + except Exception as e: + prog_bar.stop() + raise e + for result, idx in unordered_results: + results[idx] = result + return results diff --git a/tests/test_utils/test_progressbar_rich.py b/tests/test_utils/test_progressbar_rich.py new file mode 100644 index 00000000..9c507bf6 --- /dev/null +++ b/tests/test_utils/test_progressbar_rich.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine.utils import track_progress_rich + + +def foo(): + return 1 + + +def foo1(x): + return x + + +def foo2(x, y): + return x, y + + +def test_progressbar_rich_exception(): + tasks = [1] * 10 + # invalid func + with pytest.raises(TypeError): + track_progress_rich(1, tasks) + # invalid task + with pytest.raises(TypeError): + track_progress_rich(foo1, 1) + # mismatched task number + with pytest.raises(ValueError): + track_progress_rich(foo1, tasks, task_num=9) + # invalid nproc + with pytest.raises(ValueError): + track_progress_rich(foo1, tasks, nproc=0) + # empty tasks and task_num is None + with pytest.raises(ValueError): + track_progress_rich(foo1, nproc=0) + + +@pytest.mark.parametrize('nproc', [1, 2]) +def test_progressbar_rich(nproc): + # empty tasks + results = track_progress_rich(foo, nproc=nproc, task_num=10) + assert results == [1] * 10 + # Ordered results + # foo1 + tasks_ = [i for i in range(10)] + for tasks in (tasks_, iter(tasks_)): + results = track_progress_rich(foo1, tasks, nproc=nproc) + assert results == tasks_ + # foo2 + tasks_ = [(i, i + 1) for i in range(10)] + for tasks in
(tasks_, iter(tasks_)): + results = track_progress_rich(foo2, tasks, nproc=nproc) + assert results == tasks_