mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature ] Add progressbar rich (#1157)
This commit is contained in:
parent
f24144d317
commit
0939d95c93
@ -14,6 +14,7 @@ from .path import (check_file_exist, fopen, is_abs, is_filepath,
|
|||||||
mkdir_or_exist, scandir, symlink)
|
mkdir_or_exist, scandir, symlink)
|
||||||
from .progressbar import (ProgressBar, track_iter_progress,
|
from .progressbar import (ProgressBar, track_iter_progress,
|
||||||
track_parallel_progress, track_progress)
|
track_parallel_progress, track_progress)
|
||||||
|
from .progressbar_rich import track_progress_rich
|
||||||
from .timer import Timer, TimerError, check_time
|
from .timer import Timer, TimerError, check_time
|
||||||
from .version_utils import digit_version, get_git_hash
|
from .version_utils import digit_version, get_git_hash
|
||||||
|
|
||||||
@ -29,5 +30,5 @@ __all__ = [
|
|||||||
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
|
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
|
||||||
'TimerError', 'ProgressBar', 'track_iter_progress',
|
'TimerError', 'ProgressBar', 'track_iter_progress',
|
||||||
'track_parallel_progress', 'track_progress', 'deprecated_function',
|
'track_parallel_progress', 'track_progress', 'deprecated_function',
|
||||||
'apply_to', 'get_object_from_string'
|
'apply_to', 'track_progress_rich', 'get_object_from_string'
|
||||||
]
|
]
|
||||||
|
@ -3,14 +3,35 @@ import sys
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
|
from typing import Callable, Union
|
||||||
|
|
||||||
from .timer import Timer
|
from .timer import Timer
|
||||||
|
|
||||||
|
|
||||||
class ProgressBar:
|
class ProgressBar:
|
||||||
"""A progress bar which can print the progress."""
|
"""A progress bar which can print the progress.
|
||||||
|
|
||||||
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
|
Args:
|
||||||
|
task_num (int): Number of total steps. Defaults to 0.
|
||||||
|
bar_width (int): Width of the progress bar. Defaults to 50.
|
||||||
|
start (bool): Whether to start the progress bar in the constructor.
|
||||||
|
Defaults to True.
|
||||||
|
file (callable): Progress bar output mode. Defaults to "sys.stdout".
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mmengine
|
||||||
|
>>> import time
|
||||||
|
>>> bar = mmengine.ProgressBar(10)
|
||||||
|
>>> for i in range(10):
|
||||||
|
>>> bar.update()
|
||||||
|
>>> time.sleep(1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
task_num: int = 0,
|
||||||
|
bar_width: int = 50,
|
||||||
|
start: bool = True,
|
||||||
|
file=sys.stdout):
|
||||||
self.task_num = task_num
|
self.task_num = task_num
|
||||||
self.bar_width = bar_width
|
self.bar_width = bar_width
|
||||||
self.completed = 0
|
self.completed = 0
|
||||||
@ -32,7 +53,12 @@ class ProgressBar:
|
|||||||
self.file.flush()
|
self.file.flush()
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
|
|
||||||
def update(self, num_tasks=1):
|
def update(self, num_tasks: int = 1):
|
||||||
|
"""update progressbar.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_tasks (int): Update step size.
|
||||||
|
"""
|
||||||
assert num_tasks > 0
|
assert num_tasks > 0
|
||||||
self.completed += num_tasks
|
self.completed += num_tasks
|
||||||
elapsed = self.timer.since_start()
|
elapsed = self.timer.since_start()
|
||||||
@ -61,7 +87,11 @@ class ProgressBar:
|
|||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
|
|
||||||
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
|
def track_progress(func: Callable,
|
||||||
|
tasks: Union[list, Iterable],
|
||||||
|
bar_width: int = 50,
|
||||||
|
file=sys.stdout,
|
||||||
|
**kwargs):
|
||||||
"""Track the progress of tasks execution with a progress bar.
|
"""Track the progress of tasks execution with a progress bar.
|
||||||
|
|
||||||
Tasks are done with a simple for-loop.
|
Tasks are done with a simple for-loop.
|
||||||
@ -81,7 +111,7 @@ def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
|
|||||||
assert isinstance(tasks[1], int)
|
assert isinstance(tasks[1], int)
|
||||||
task_num = tasks[1]
|
task_num = tasks[1]
|
||||||
tasks = tasks[0]
|
tasks = tasks[0]
|
||||||
elif isinstance(tasks, Iterable):
|
elif isinstance(tasks, list):
|
||||||
task_num = len(tasks)
|
task_num = len(tasks)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -106,15 +136,15 @@ def init_pool(process_num, initializer=None, initargs=None):
|
|||||||
return Pool(process_num, initializer, initargs)
|
return Pool(process_num, initializer, initargs)
|
||||||
|
|
||||||
|
|
||||||
def track_parallel_progress(func,
|
def track_parallel_progress(func: Callable,
|
||||||
tasks,
|
tasks: Union[list, Iterable],
|
||||||
nproc,
|
nproc: int,
|
||||||
initializer=None,
|
initializer: Callable = None,
|
||||||
initargs=None,
|
initargs: tuple = None,
|
||||||
bar_width=50,
|
bar_width: int = 50,
|
||||||
chunksize=1,
|
chunksize: int = 1,
|
||||||
skip_first=False,
|
skip_first: bool = False,
|
||||||
keep_order=True,
|
keep_order: bool = True,
|
||||||
file=sys.stdout):
|
file=sys.stdout):
|
||||||
"""Track the progress of parallel task execution with a progress bar.
|
"""Track the progress of parallel task execution with a progress bar.
|
||||||
|
|
||||||
@ -147,7 +177,7 @@ def track_parallel_progress(func,
|
|||||||
assert isinstance(tasks[1], int)
|
assert isinstance(tasks[1], int)
|
||||||
task_num = tasks[1]
|
task_num = tasks[1]
|
||||||
tasks = tasks[0]
|
tasks = tasks[0]
|
||||||
elif isinstance(tasks, Iterable):
|
elif isinstance(tasks, list):
|
||||||
task_num = len(tasks)
|
task_num = len(tasks)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -176,7 +206,9 @@ def track_parallel_progress(func,
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
|
def track_iter_progress(tasks: Union[list, Iterable],
|
||||||
|
bar_width: int = 50,
|
||||||
|
file=sys.stdout):
|
||||||
"""Track the progress of tasks iteration or enumeration with a progress
|
"""Track the progress of tasks iteration or enumeration with a progress
|
||||||
bar.
|
bar.
|
||||||
|
|
||||||
@ -196,7 +228,7 @@ def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
|
|||||||
assert isinstance(tasks[1], int)
|
assert isinstance(tasks[1], int)
|
||||||
task_num = tasks[1]
|
task_num = tasks[1]
|
||||||
tasks = tasks[0]
|
tasks = tasks[0]
|
||||||
elif isinstance(tasks, Iterable):
|
elif isinstance(tasks, list):
|
||||||
task_num = len(tasks)
|
task_num = len(tasks)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
151
mmengine/utils/progressbar_rich.py
Normal file
151
mmengine/utils/progressbar_rich.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from typing import Callable, Iterable, Sized
|
||||||
|
|
||||||
|
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
|
||||||
|
TaskProgressColumn, TextColumn, TimeRemainingColumn)
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
|
||||||
|
class _Worker:
|
||||||
|
"""Function wrapper for ``track_progress_rich``"""
|
||||||
|
|
||||||
|
def __init__(self, func) -> None:
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def __call__(self, inputs):
|
||||||
|
inputs, idx = inputs
|
||||||
|
if not isinstance(inputs, (tuple, list)):
|
||||||
|
inputs = (inputs, )
|
||||||
|
|
||||||
|
return self.func(*inputs), idx
|
||||||
|
|
||||||
|
|
||||||
|
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
|
||||||
|
"""Skip calculating remaining time for the first few times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip_times (int): The number of times to skip. Defaults to 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, skip_times=0, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.skip_times = skip_times
|
||||||
|
|
||||||
|
def render(self, task: Task) -> Text:
|
||||||
|
"""Show time remaining."""
|
||||||
|
if task.completed <= self.skip_times:
|
||||||
|
return Text('-:--:--', style='progress.remaining')
|
||||||
|
return super().render(task)
|
||||||
|
|
||||||
|
|
||||||
|
def _tasks_with_index(tasks):
|
||||||
|
"""Add index to tasks."""
|
||||||
|
for idx, task in enumerate(tasks):
|
||||||
|
yield task, idx
|
||||||
|
|
||||||
|
|
||||||
|
def track_progress_rich(func: Callable,
|
||||||
|
tasks: Iterable = tuple(),
|
||||||
|
task_num: int = None,
|
||||||
|
nproc: int = 1,
|
||||||
|
chunksize: int = 1,
|
||||||
|
description: str = 'Processing',
|
||||||
|
color: str = 'blue') -> list:
|
||||||
|
"""Track the progress of parallel task execution with a progress bar. The
|
||||||
|
built-in :mod:`multiprocessing` module is used for process pools and tasks
|
||||||
|
are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (callable): The function to be applied to each task.
|
||||||
|
tasks (Iterable or Sized): A tuple of tasks. There are several cases
|
||||||
|
for different format tasks:
|
||||||
|
- When ``func`` accepts no arguments: tasks should be an empty
|
||||||
|
tuple, and ``task_num`` must be specified.
|
||||||
|
- When ``func`` accepts only one argument: tasks should be a tuple
|
||||||
|
containing the argument.
|
||||||
|
- When ``func`` accepts multiple arguments: tasks should be a
|
||||||
|
tuple, with each element representing a set of arguments.
|
||||||
|
If an element is a ``dict``, it will be parsed as a set of
|
||||||
|
keyword-only arguments.
|
||||||
|
Defaults to an empty tuple.
|
||||||
|
task_num (int, optional): If ``tasks`` is an iterator which does not
|
||||||
|
have length, the number of tasks can be provided by ``task_num``.
|
||||||
|
Defaults to None.
|
||||||
|
nproc (int): Process (worker) number, if nuproc is 1,
|
||||||
|
use single process. Defaults to 1.
|
||||||
|
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
|
||||||
|
Defaults to 1.
|
||||||
|
description (str): The description of progress bar.
|
||||||
|
Defaults to "Process".
|
||||||
|
color (str): The color of progress bar. Defaults to "blue".
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import time
|
||||||
|
|
||||||
|
>>> def func(x):
|
||||||
|
... time.sleep(1)
|
||||||
|
... return x**2
|
||||||
|
>>> track_progress_rich(func, range(10), nproc=2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The task results.
|
||||||
|
"""
|
||||||
|
if not callable(func):
|
||||||
|
raise TypeError('func must be a callable object')
|
||||||
|
if not isinstance(tasks, Iterable):
|
||||||
|
raise TypeError(
|
||||||
|
f'tasks must be an iterable object, but got {type(tasks)}')
|
||||||
|
if isinstance(tasks, Sized):
|
||||||
|
if len(tasks) == 0:
|
||||||
|
if task_num is None:
|
||||||
|
raise ValueError('If tasks is an empty iterable, '
|
||||||
|
'task_num must be set')
|
||||||
|
else:
|
||||||
|
tasks = tuple(tuple() for _ in range(task_num))
|
||||||
|
else:
|
||||||
|
if task_num is not None and task_num != len(tasks):
|
||||||
|
raise ValueError('task_num does not match the length of tasks')
|
||||||
|
task_num = len(tasks)
|
||||||
|
|
||||||
|
if nproc <= 0:
|
||||||
|
raise ValueError('nproc must be a positive number')
|
||||||
|
|
||||||
|
skip_times = nproc * chunksize if nproc > 1 else 0
|
||||||
|
prog_bar = Progress(
|
||||||
|
TextColumn('{task.description}'),
|
||||||
|
BarColumn(),
|
||||||
|
_SkipFirstTimeRemainingColumn(skip_times=skip_times),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TaskProgressColumn(show_speed=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
worker = _Worker(func)
|
||||||
|
task_id = prog_bar.add_task(
|
||||||
|
total=task_num, color=color, description=description)
|
||||||
|
tasks = _tasks_with_index(tasks)
|
||||||
|
|
||||||
|
# Use single process when nproc is 1, else use multiprocess.
|
||||||
|
with prog_bar:
|
||||||
|
if nproc == 1:
|
||||||
|
results = []
|
||||||
|
for task in tasks:
|
||||||
|
results.append(worker(task)[0])
|
||||||
|
prog_bar.update(task_id, advance=1, refresh=True)
|
||||||
|
else:
|
||||||
|
with Pool(nproc) as pool:
|
||||||
|
results = []
|
||||||
|
unordered_results = []
|
||||||
|
gen = pool.imap_unordered(worker, tasks, chunksize)
|
||||||
|
try:
|
||||||
|
for result in gen:
|
||||||
|
result, idx = result
|
||||||
|
unordered_results.append((result, idx))
|
||||||
|
results.append(None)
|
||||||
|
prog_bar.update(task_id, advance=1, refresh=True)
|
||||||
|
except Exception as e:
|
||||||
|
prog_bar.stop()
|
||||||
|
raise e
|
||||||
|
for result, idx in unordered_results:
|
||||||
|
results[idx] = result
|
||||||
|
return results
|
53
tests/test_utils/test_progressbar_rich.py
Normal file
53
tests/test_utils/test_progressbar_rich.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mmengine.utils import track_progress_rich
|
||||||
|
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def foo1(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def foo2(x, y):
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
def test_progressbar_rich_exception():
|
||||||
|
tasks = [1] * 10
|
||||||
|
# Valid func
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
track_progress_rich(1, tasks)
|
||||||
|
# invalid task
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
track_progress_rich(foo1, 1)
|
||||||
|
# mismatched task number
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
track_progress_rich(foo1, tasks, task_num=9)
|
||||||
|
# invalid proc
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
track_progress_rich(foo1, tasks, nproc=0)
|
||||||
|
# empty tasks and task_num is None
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
track_progress_rich(foo1, nproc=0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('nproc', [1, 2])
|
||||||
|
def test_progressbar_rich(nproc):
|
||||||
|
# empty tasks
|
||||||
|
results = track_progress_rich(foo, nproc=nproc, task_num=10)
|
||||||
|
assert results == [1] * 10
|
||||||
|
# Ordered results
|
||||||
|
# foo1
|
||||||
|
tasks_ = [i for i in range(10)]
|
||||||
|
for tasks in (tasks_, iter(tasks_)):
|
||||||
|
results = track_progress_rich(foo1, tasks, nproc=nproc)
|
||||||
|
assert results == tasks_
|
||||||
|
# foo2
|
||||||
|
tasks_ = [(i, i + 1) for i in range(10)]
|
||||||
|
for tasks in (tasks_, iter(tasks_)):
|
||||||
|
results = track_progress_rich(foo2, tasks, nproc=nproc)
|
||||||
|
assert results == tasks_
|
Loading…
x
Reference in New Issue
Block a user