[Feature ] Add progressbar rich (#1157)

This commit is contained in:
王永韬 2023-08-30 20:10:07 +08:00 committed by GitHub
parent f24144d317
commit 0939d95c93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 255 additions and 18 deletions

View File

@ -14,6 +14,7 @@ from .path import (check_file_exist, fopen, is_abs, is_filepath,
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .progressbar_rich import track_progress_rich
from .timer import Timer, TimerError, check_time
from .version_utils import digit_version, get_git_hash
@ -29,5 +30,5 @@ __all__ = [
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress', 'deprecated_function',
'apply_to', 'get_object_from_string'
'apply_to', 'track_progress_rich', 'get_object_from_string'
]

View File

@ -3,14 +3,35 @@ import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size
from typing import Callable, Union
from .timer import Timer
class ProgressBar:
"""A progress bar which can print the progress."""
"""A progress bar which can print the progress.
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
Args:
task_num (int): Number of total steps. Defaults to 0.
bar_width (int): Width of the progress bar. Defaults to 50.
start (bool): Whether to start the progress bar in the constructor.
Defaults to True.
file (callable): Progress bar output mode. Defaults to "sys.stdout".
Examples:
>>> import mmengine
>>> import time
>>> bar = mmengine.ProgressBar(10)
>>> for i in range(10):
>>> bar.update()
>>> time.sleep(1)
"""
def __init__(self,
task_num: int = 0,
bar_width: int = 50,
start: bool = True,
file=sys.stdout):
self.task_num = task_num
self.bar_width = bar_width
self.completed = 0
@ -32,7 +53,12 @@ class ProgressBar:
self.file.flush()
self.timer = Timer()
def update(self, num_tasks=1):
def update(self, num_tasks: int = 1):
"""update progressbar.
Args:
num_tasks (int): Update step size.
"""
assert num_tasks > 0
self.completed += num_tasks
elapsed = self.timer.since_start()
@ -61,7 +87,11 @@ class ProgressBar:
self.file.flush()
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
def track_progress(func: Callable,
tasks: Union[list, Iterable],
bar_width: int = 50,
file=sys.stdout,
**kwargs):
"""Track the progress of tasks execution with a progress bar.
Tasks are done with a simple for-loop.
@ -81,7 +111,7 @@ def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, Iterable):
elif isinstance(tasks, list):
task_num = len(tasks)
else:
raise TypeError(
@ -106,15 +136,15 @@ def init_pool(process_num, initializer=None, initargs=None):
return Pool(process_num, initializer, initargs)
def track_parallel_progress(func,
tasks,
nproc,
initializer=None,
initargs=None,
bar_width=50,
chunksize=1,
skip_first=False,
keep_order=True,
def track_parallel_progress(func: Callable,
tasks: Union[list, Iterable],
nproc: int,
initializer: Callable = None,
initargs: tuple = None,
bar_width: int = 50,
chunksize: int = 1,
skip_first: bool = False,
keep_order: bool = True,
file=sys.stdout):
"""Track the progress of parallel task execution with a progress bar.
@ -147,7 +177,7 @@ def track_parallel_progress(func,
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, Iterable):
elif isinstance(tasks, list):
task_num = len(tasks)
else:
raise TypeError(
@ -176,7 +206,9 @@ def track_parallel_progress(func,
return results
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
def track_iter_progress(tasks: Union[list, Iterable],
bar_width: int = 50,
file=sys.stdout):
"""Track the progress of tasks iteration or enumeration with a progress
bar.
@ -196,7 +228,7 @@ def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, Iterable):
elif isinstance(tasks, list):
task_num = len(tasks)
else:
raise TypeError(

View File

@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing import Pool
from typing import Callable, Iterable, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn)
from rich.text import Text
class _Worker:
"""Function wrapper for ``track_progress_rich``"""
def __init__(self, func) -> None:
self.func = func
def __call__(self, inputs):
inputs, idx = inputs
if not isinstance(inputs, (tuple, list)):
inputs = (inputs, )
return self.func(*inputs), idx
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Time-remaining column that shows a placeholder for early updates.

    While the completed count of a task is not greater than ``skip_times``
    the column renders ``-:--:--`` instead of delegating to
    ``TimeRemainingColumn.render``.

    Args:
        skip_times (int): Completed-count threshold up to which the
            placeholder is shown. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Render the remaining time, or a placeholder for early updates."""
        if task.completed > self.skip_times:
            return super().render(task)
        return Text('-:--:--', style='progress.remaining')
def _tasks_with_index(tasks):
"""Add index to tasks."""
for idx, task in enumerate(tasks):
yield task, idx
def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        color: str = 'blue') -> list:
    """Run ``func`` over ``tasks`` while displaying a rich progress bar.

    With ``nproc == 1`` the tasks are executed one after another in the
    current process. Otherwise a :class:`multiprocessing.Pool` with
    ``nproc`` workers is created and the tasks are dispatched with
    :meth:`Pool.imap_unordered`; the results are put back into the order
    of ``tasks`` before being returned.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): The tasks to process. Every element
            provides the arguments of one call of ``func``:

            - When ``func`` accepts no arguments: leave ``tasks`` empty and
              specify ``task_num``; ``func`` is then called ``task_num``
              times.
            - When ``func`` accepts only one argument: each element that
              is neither a tuple nor a list is passed as that argument.
            - When ``func`` accepts multiple arguments: each element
              should be a tuple or list, which is unpacked into
              positional arguments.

            A tuple or list is always unpacked, so a single argument that
            is itself a tuple or list must be wrapped, e.g. ``([1, 2], )``.
            A ``dict`` element is not expanded into keyword arguments; it
            is passed as one positional argument.
            Defaults to an empty tuple.
        task_num (int, optional): Number of tasks, used as the total of
            the progress bar. Required when ``tasks`` is empty. If
            ``tasks`` has a non-zero length, ``task_num`` must either be
            None or equal to that length. If ``tasks`` has no length
            (e.g. an iterator), ``task_num`` is forwarded unchanged.
            Defaults to None.
        nproc (int): Process (worker) number. If ``nproc`` is 1, tasks run
            in the current process without a pool. Defaults to 1.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
            Only used when ``nproc`` is greater than 1. Defaults to 1.
        description (str): The description of progress bar.
            Defaults to "Processing".
        color (str): Stored as the ``color`` field of the progress task.
            Defaults to "blue".

    Returns:
        list: The task results, in the same order as ``tasks``.

    Raises:
        TypeError: If ``func`` is not callable or ``tasks`` is not iterable.
        ValueError: If ``tasks`` is empty and ``task_num`` is None, if
            ``task_num`` does not match ``len(tasks)``, or if ``nproc`` is
            not positive.

    Examples:
        >>> import time
        >>> def func(x):
        ...     time.sleep(1)
        ...     return x**2
        >>> track_progress_rich(func, range(10), nproc=2)
    """
    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')

    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            # One empty argument tuple per call, i.e. ``func()`` is
            # invoked ``task_num`` times.
            tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    # The remaining-time column shows a placeholder until more than
    # ``nproc * chunksize`` tasks have completed in multi-process mode.
    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    # NOTE(review): ``color`` is only attached to the task as a field; none
    # of the columns configured above reference it - confirm whether the
    # bar or description was meant to be styled with it.
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    tasks = _tasks_with_index(tasks)

    # Use single process when nproc is 1, else use multiprocess.
    with prog_bar:
        if nproc == 1:
            results = []
            for task in tasks:
                results.append(worker(task)[0])
                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            with Pool(nproc) as pool:
                # Results arrive in completion order as (result, idx).
                unordered_results = []
                gen = pool.imap_unordered(worker, tasks, chunksize)
                try:
                    for indexed_result in gen:
                        unordered_results.append(indexed_result)
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception:
                    # Stop the live display before the error propagates.
                    prog_bar.stop()
                    raise

            # Restore the input order using the index of each result.
            results = [None] * len(unordered_results)
            for result, idx in unordered_results:
                results[idx] = result
    return results

View File

@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine.utils import track_progress_rich
def foo():
    """Zero-argument task function used by the tests; always returns 1."""
    return 1
def foo1(x):
    """Single-argument task function used by the tests; returns ``x``."""
    return x
def foo2(x, y):
    """Two-argument task function used by the tests; returns ``(x, y)``."""
    return x, y
def test_progressbar_rich_exception():
    """Check that invalid arguments raise the expected exception types."""
    tasks = [1] * 10
    # invalid func: not callable
    with pytest.raises(TypeError):
        track_progress_rich(1, tasks)
    # invalid tasks: not iterable
    with pytest.raises(TypeError):
        track_progress_rich(foo1, 1)
    # mismatched task number
    with pytest.raises(ValueError):
        track_progress_rich(foo1, tasks, task_num=9)
    # invalid nproc
    with pytest.raises(ValueError):
        track_progress_rich(foo1, tasks, nproc=0)
    # empty tasks and task_num is None; nproc keeps its valid default so
    # that only the empty-tasks check can raise the ValueError here
    with pytest.raises(ValueError):
        track_progress_rich(foo1)
@pytest.mark.parametrize('nproc', [1, 2])
def test_progressbar_rich(nproc):
    """Check results and their order for sized and unsized task inputs."""
    # No tasks given: the function is called ``task_num`` times.
    assert track_progress_rich(foo, nproc=nproc, task_num=10) == [1] * 10

    # Results must come back in input order, for a list as well as for an
    # iterator without a length.
    # Single-argument function.
    numbers = list(range(10))
    for candidate in (numbers, iter(numbers)):
        assert track_progress_rich(foo1, candidate, nproc=nproc) == numbers

    # Two-argument function: every tuple is unpacked into ``x`` and ``y``.
    pairs = [(i, i + 1) for i in range(10)]
    for candidate in (pairs, iter(pairs)):
        assert track_progress_rich(foo2, candidate, nproc=nproc) == pairs