Fix the type check of tasks in progress bar (#1340)
parent
8c934d2681
commit
9aa883a24c
|
@ -3,7 +3,7 @@ import sys
|
|||
from collections.abc import Iterable
|
||||
from multiprocessing import Pool
|
||||
from shutil import get_terminal_size
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from .timer import Timer
|
||||
|
||||
|
@ -88,7 +88,7 @@ class ProgressBar:
|
|||
|
||||
|
||||
def track_progress(func: Callable,
|
||||
tasks: Union[list, Iterable],
|
||||
tasks: Sequence,
|
||||
bar_width: int = 50,
|
||||
file=sys.stdout,
|
||||
**kwargs):
|
||||
|
@ -98,8 +98,10 @@ def track_progress(func: Callable,
|
|||
|
||||
Args:
|
||||
func (callable): The function to be applied to each task.
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
tasks (Sequence): If tasks is a tuple, it must contain two elements,
|
||||
the first being the tasks to be completed and the other being the
|
||||
number of tasks. If it is not a tuple, it represents the tasks to
|
||||
be completed.
|
||||
bar_width (int): Width of progress bar.
|
||||
|
||||
Returns:
|
||||
|
@ -110,13 +112,13 @@ def track_progress(func: Callable,
|
|||
assert isinstance(tasks[0], Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, list):
|
||||
tasks = tasks[0] # type: ignore
|
||||
elif isinstance(tasks, Sequence):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple, '
|
||||
f'but got {type(tasks)}')
|
||||
'"tasks" must be a tuple object or a sequence object, but got '
|
||||
f'{type(tasks)}')
|
||||
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
||||
results = []
|
||||
for task in tasks:
|
||||
|
@ -138,7 +140,7 @@ def init_pool(process_num, initializer=None, initargs=None):
|
|||
|
||||
|
||||
def track_parallel_progress(func: Callable,
|
||||
tasks: Union[list, Iterable],
|
||||
tasks: Sequence,
|
||||
nproc: int,
|
||||
initializer: Callable = None,
|
||||
initargs: tuple = None,
|
||||
|
@ -154,8 +156,10 @@ def track_parallel_progress(func: Callable,
|
|||
|
||||
Args:
|
||||
func (callable): The function to be applied to each task.
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
tasks (Sequence): If tasks is a tuple, it must contain two elements,
|
||||
the first being the tasks to be completed and the other being the
|
||||
number of tasks. If it is not a tuple, it represents the tasks to
|
||||
be completed.
|
||||
nproc (int): Process (worker) number.
|
||||
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
|
||||
for details.
|
||||
|
@ -177,13 +181,13 @@ def track_parallel_progress(func: Callable,
|
|||
assert isinstance(tasks[0], Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, list):
|
||||
tasks = tasks[0] # type: ignore
|
||||
elif isinstance(tasks, Sequence):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple, '
|
||||
f'but got {type(tasks)}')
|
||||
'"tasks" must be a tuple object or a sequence object, but got '
|
||||
f'{type(tasks)}')
|
||||
pool = init_pool(nproc, initializer, initargs)
|
||||
start = not skip_first
|
||||
task_num -= nproc * chunksize * int(skip_first)
|
||||
|
@ -208,17 +212,17 @@ def track_parallel_progress(func: Callable,
|
|||
return results
|
||||
|
||||
|
||||
def track_iter_progress(tasks: Union[list, Iterable],
|
||||
bar_width: int = 50,
|
||||
file=sys.stdout):
|
||||
def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout):
|
||||
"""Track the progress of tasks iteration or enumeration with a progress
|
||||
bar.
|
||||
|
||||
Tasks are yielded with a simple for-loop.
|
||||
|
||||
Args:
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
tasks (Sequence): If tasks is a tuple, it must contain two elements,
|
||||
the first being the tasks to be completed and the other being the
|
||||
number of tasks. If it is not a tuple, it represents the tasks to
|
||||
be completed.
|
||||
bar_width (int): Width of progress bar.
|
||||
|
||||
Yields:
|
||||
|
@ -229,13 +233,13 @@ def track_iter_progress(tasks: Union[list, Iterable],
|
|||
assert isinstance(tasks[0], Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, list):
|
||||
tasks = tasks[0] # type: ignore
|
||||
elif isinstance(tasks, Sequence):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple, '
|
||||
f'but got {type(tasks)}')
|
||||
'"tasks" must be a tuple object or a sequence object, but got '
|
||||
f'{type(tasks)}')
|
||||
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
||||
for task in tasks:
|
||||
yield task
|
||||
|
|
|
@ -91,100 +91,71 @@ def sleep_1s(num):
|
|||
return num
|
||||
|
||||
|
||||
@skipIf(
|
||||
platform.system() != 'Linux',
|
||||
reason='Only test `test_track_progress_list` in Linux')
|
||||
def test_track_progress_list():
|
||||
def return_itself(num):
|
||||
return num
|
||||
|
||||
|
||||
def test_track_progress():
|
||||
# tasks is a list
|
||||
out = StringIO()
|
||||
ret = mmengine.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
if platform == 'Linux':
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
@skipIf(
|
||||
platform.system() != 'Linux',
|
||||
reason='Only test `test_track_progress_iterator` in Linux')
|
||||
def test_track_progress_iterator():
|
||||
out = StringIO()
|
||||
# tasks is an iterable object
|
||||
ret = mmengine.track_progress(
|
||||
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
return_itself, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
# tasks is a range object
|
||||
ret = mmengine.track_progress(
|
||||
return_itself, range(1, 4), bar_width=3, file=out)
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
@skipIf(
|
||||
platform.system() != 'Linux',
|
||||
reason='Only test `test_track_iter_progress` in Linux')
|
||||
def test_track_iter_progress():
|
||||
out = StringIO()
|
||||
ret = []
|
||||
for num in mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out):
|
||||
ret.append(sleep_1s(num))
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
ret.append(num)
|
||||
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
@skipIf(
|
||||
platform.system() != 'Linux',
|
||||
reason='Only test `test_track_enum_progress` in Linux')
|
||||
def test_track_enum_progress():
|
||||
out = StringIO()
|
||||
ret = []
|
||||
count = []
|
||||
for i, num in enumerate(
|
||||
mmengine.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
|
||||
ret.append(sleep_1s(num))
|
||||
ret.append(num)
|
||||
count.append(i)
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
assert count == [0, 1, 2]
|
||||
|
||||
# tasks is a range object
|
||||
res = mmengine.track_iter_progress(range(1, 4), bar_width=3, file=out)
|
||||
assert list(res) == [1, 2, 3]
|
||||
|
||||
@skipIf(
|
||||
platform.system() != 'Linux',
|
||||
reason='Only test `test_track_parallel_progress_list` in Linux')
|
||||
def test_track_parallel_progress_list():
|
||||
|
||||
def test_track_parallel_progress():
|
||||
# tasks is a list
|
||||
out = StringIO()
|
||||
results = mmengine.track_parallel_progress(
|
||||
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
|
||||
# The following cannot pass CI on Github Action
|
||||
# assert out.getvalue() == (
|
||||
# '[ ] 0/4, elapsed: 0s, ETA:'
|
||||
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
|
||||
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
|
||||
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
|
||||
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
|
||||
assert results == [1, 2, 3, 4]
|
||||
ret = mmengine.track_parallel_progress(
|
||||
return_itself, [1, 2, 3, 4], 2, bar_width=4, file=out)
|
||||
assert ret == [1, 2, 3, 4]
|
||||
|
||||
# tasks is an iterable object
|
||||
ret = mmengine.track_parallel_progress(
|
||||
return_itself, ((i for i in [1, 2, 3, 4]), 4),
|
||||
2,
|
||||
bar_width=4,
|
||||
file=out)
|
||||
assert ret == [1, 2, 3, 4]
|
||||
|
||||
@skipIf(
|
||||
platform.system() != 'Linux',
|
||||
reason='Only test `test_track_parallel_progress_iterator` in Linux')
|
||||
def test_track_parallel_progress_iterator():
|
||||
out = StringIO()
|
||||
results = mmengine.track_parallel_progress(
|
||||
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
|
||||
# The following cannot pass CI on Github Action
|
||||
# assert out.getvalue() == (
|
||||
# '[ ] 0/4, elapsed: 0s, ETA:'
|
||||
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
|
||||
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
|
||||
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
|
||||
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
|
||||
assert results == [1, 2, 3, 4]
|
||||
# tasks is a range object
|
||||
ret = mmengine.track_parallel_progress(
|
||||
return_itself, range(1, 5), 2, bar_width=4, file=out)
|
||||
assert ret == [1, 2, 3, 4]
|
||||
|
|
Loading…
Reference in New Issue