mirror of https://github.com/open-mmlab/mmcv.git
Track iter progress (#112)
* track progress of iter&enum * reformat * reformat with yapf * add unitest * add doc, and deprecate track_enum_progress * update docs & commentspull/117/head
parent
9de7927a69
commit
d5865e0cbd
|
@ -54,6 +54,25 @@ mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
|
|||
|
||||

|
||||
|
||||
If you want to iterate or enumerate a list of items and track the progress, `track_iter_progress`
|
||||
is a good choice. It will display a progress bar to tell the progress and ETA.
|
||||
|
||||
```python
|
||||
import mmcv
|
||||
|
||||
tasks = [item_1, item_2, ..., item_n]
|
||||
|
||||
for task in mmcv.track_iter_progress(tasks):
|
||||
# do something like print
|
||||
print(task)
|
||||
|
||||
for i, task in enumerate(mmcv.track_iter_progress(tasks)):
|
||||
# do something like print
|
||||
print(i)
|
||||
print(task)
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Timer
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@ from .misc import (is_str, iter_cast, list_cast, tuple_cast, is_seq_of,
|
|||
check_prerequisites, requires_package, requires_executable)
|
||||
from .path import (is_filepath, fopen, check_file_exist, mkdir_or_exist,
|
||||
symlink, scandir, FileNotFoundError)
|
||||
from .progressbar import ProgressBar, track_progress, track_parallel_progress
|
||||
from .progressbar import (ProgressBar, track_progress, track_parallel_progress,
|
||||
track_iter_progress)
|
||||
from .timer import Timer, TimerError, check_time
|
||||
|
||||
__all__ = [
|
||||
|
@ -13,5 +14,6 @@ __all__ = [
|
|||
'check_prerequisites', 'requires_package', 'requires_executable',
|
||||
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
|
||||
'scandir', 'FileNotFoundError', 'ProgressBar', 'track_progress',
|
||||
'track_parallel_progress', 'Timer', 'TimerError', 'check_time'
|
||||
'track_iter_progress', 'track_parallel_progress', 'Timer', 'TimerError',
|
||||
'check_time'
|
||||
]
|
||||
|
|
|
@ -172,3 +172,34 @@ def track_parallel_progress(func,
|
|||
pool.close()
|
||||
pool.join()
|
||||
return results
|
||||
|
||||
|
||||
def track_iter_progress(tasks, bar_width=50, **kwargs):
|
||||
"""Track the progress of tasks iteration or enumeration with a progress bar.
|
||||
|
||||
Tasks are yielded with a simple for-loop.
|
||||
|
||||
Args:
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
bar_width (int): Width of progress bar.
|
||||
|
||||
Yields:
|
||||
list: The task results.
|
||||
"""
|
||||
if isinstance(tasks, tuple):
|
||||
assert len(tasks) == 2
|
||||
assert isinstance(tasks[0], collections_abc.Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, collections_abc.Iterable):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||
prog_bar = ProgressBar(task_num, bar_width)
|
||||
for task in tasks:
|
||||
yield task
|
||||
prog_bar.update()
|
||||
sys.stdout.write('\n')
|
||||
|
|
|
@ -81,6 +81,33 @@ def test_track_progress_iterator(capsys):
|
|||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
def test_track_iter_progress(capsys):
|
||||
ret = []
|
||||
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3):
|
||||
ret.append(sleep_1s(num))
|
||||
out, _ = capsys.readouterr()
|
||||
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
def test_track_enum_progress(capsys):
|
||||
ret = []
|
||||
count = []
|
||||
for i, num in enumerate(mmcv.track_iter_progress([1, 2, 3], bar_width=3)):
|
||||
ret.append(sleep_1s(num))
|
||||
count.append(i)
|
||||
out, _ = capsys.readouterr()
|
||||
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
assert count == [0, 1, 2]
|
||||
|
||||
|
||||
def test_track_parallel_progress_list(capsys):
|
||||
|
||||
results = mmcv.track_parallel_progress(
|
||||
|
|
Loading…
Reference in New Issue