Fix the type check of tasks in progress bar (#1340)

pull/1341/head
Zaida Zhou 2023-09-04 19:27:58 +08:00 committed by GitHub
parent 8c934d2681
commit 9aa883a24c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 95 deletions

View File

@ -3,7 +3,7 @@ import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size
from typing import Callable, Union
from typing import Callable, Sequence
from .timer import Timer
@ -88,7 +88,7 @@ class ProgressBar:
def track_progress(func: Callable,
tasks: Union[list, Iterable],
tasks: Sequence,
bar_width: int = 50,
file=sys.stdout,
**kwargs):
@ -98,8 +98,10 @@ def track_progress(func: Callable,
Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
tasks (Sequence): If tasks is a tuple, it must contain two elements,
the first being the tasks to be completed and the other being the
number of tasks. If it is not a tuple, it represents the tasks to
be completed.
bar_width (int): Width of progress bar.
Returns:
@ -110,13 +112,13 @@ def track_progress(func: Callable,
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, list):
tasks = tasks[0] # type: ignore
elif isinstance(tasks, Sequence):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple, '
f'but got {type(tasks)}')
'"tasks" must be a tuple object or a sequence object, but got '
f'{type(tasks)}')
prog_bar = ProgressBar(task_num, bar_width, file=file)
results = []
for task in tasks:
@ -138,7 +140,7 @@ def init_pool(process_num, initializer=None, initargs=None):
def track_parallel_progress(func: Callable,
tasks: Union[list, Iterable],
tasks: Sequence,
nproc: int,
initializer: Callable = None,
initargs: tuple = None,
@ -154,8 +156,10 @@ def track_parallel_progress(func: Callable,
Args:
func (callable): The function to be applied to each task.
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
tasks (Sequence): If tasks is a tuple, it must contain two elements,
the first being the tasks to be completed and the other being the
number of tasks. If it is not a tuple, it represents the tasks to
be completed.
nproc (int): Process (worker) number.
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
for details.
@ -177,13 +181,13 @@ def track_parallel_progress(func: Callable,
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, list):
tasks = tasks[0] # type: ignore
elif isinstance(tasks, Sequence):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple, '
f'but got {type(tasks)}')
'"tasks" must be a tuple object or a sequence object, but got '
f'{type(tasks)}')
pool = init_pool(nproc, initializer, initargs)
start = not skip_first
task_num -= nproc * chunksize * int(skip_first)
@ -208,17 +212,17 @@ def track_parallel_progress(func: Callable,
return results
def track_iter_progress(tasks: Union[list, Iterable],
bar_width: int = 50,
file=sys.stdout):
def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout):
"""Track the progress of tasks iteration or enumeration with a progress
bar.
Tasks are yielded with a simple for-loop.
Args:
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
tasks (Sequence): If tasks is a tuple, it must contain two elements,
the first being the tasks to be completed and the other being the
number of tasks. If it is not a tuple, it represents the tasks to
be completed.
bar_width (int): Width of progress bar.
Yields:
@ -229,13 +233,13 @@ def track_iter_progress(tasks: Union[list, Iterable],
assert isinstance(tasks[0], Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, list):
tasks = tasks[0] # type: ignore
elif isinstance(tasks, Sequence):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple, '
f'but got {type(tasks)}')
'"tasks" must be a tuple object or a sequence object, but got '
f'{type(tasks)}')
prog_bar = ProgressBar(task_num, bar_width, file=file)
for task in tasks:
yield task

View File

@ -91,100 +91,71 @@ def sleep_1s(num):
return num
@skipIf(
platform.system() != 'Linux',
reason='Only test `test_track_progress_list` in Linux')
def test_track_progress_list():
def return_itself(num):
return num
def test_track_progress():
# tasks is a list
out = StringIO()
ret = mmengine.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
if platform == 'Linux':
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
@skipIf(
platform.system() != 'Linux',
reason='Only test `test_track_progress_iterator` in Linux')
def test_track_progress_iterator():
out = StringIO()
# tasks is an iterable object
ret = mmengine.track_progress(
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
return_itself, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
assert ret == [1, 2, 3]
# tasks is a range object
ret = mmengine.track_progress(
return_itself, range(1, 4), bar_width=3, file=out)
assert ret == [1, 2, 3]
@skipIf(
platform.system() != 'Linux',
reason='Only test `test_track_iter_progress` in Linux')
def test_track_iter_progress():
out = StringIO()
ret = []
for num in mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out):
ret.append(sleep_1s(num))
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
ret.append(num)
assert ret == [1, 2, 3]
@skipIf(
platform.system() != 'Linux',
reason='Only test `test_track_enum_progress` in Linux')
def test_track_enum_progress():
out = StringIO()
ret = []
count = []
for i, num in enumerate(
mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
ret.append(sleep_1s(num))
ret.append(num)
count.append(i)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
assert count == [0, 1, 2]
# tasks is a range object
res = mmengine.track_iter_progress(range(1, 4), bar_width=3, file=out)
assert list(res) == [1, 2, 3]
@skipIf(
platform.system() != 'Linux',
reason='Only test `test_track_parallel_progress_list` in Linux')
def test_track_parallel_progress_list():
def test_track_parallel_progress():
# tasks is a list
out = StringIO()
results = mmengine.track_parallel_progress(
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
# The following cannot pass CI on Github Action
# assert out.getvalue() == (
# '[ ] 0/4, elapsed: 0s, ETA:'
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
assert results == [1, 2, 3, 4]
ret = mmengine.track_parallel_progress(
return_itself, [1, 2, 3, 4], 2, bar_width=4, file=out)
assert ret == [1, 2, 3, 4]
# tasks is an iterable object
ret = mmengine.track_parallel_progress(
return_itself, ((i for i in [1, 2, 3, 4]), 4),
2,
bar_width=4,
file=out)
assert ret == [1, 2, 3, 4]
@skipIf(
platform.system() != 'Linux',
reason='Only test `test_track_parallel_progress_iterator` in Linux')
def test_track_parallel_progress_iterator():
out = StringIO()
results = mmengine.track_parallel_progress(
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
# The following cannot pass CI on Github Action
# assert out.getvalue() == (
# '[ ] 0/4, elapsed: 0s, ETA:'
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
assert results == [1, 2, 3, 4]
# tasks is a range object
ret = mmengine.track_parallel_progress(
return_itself, range(1, 5), 2, bar_width=4, file=out)
assert ret == [1, 2, 3, 4]