Track iter progress (#112)

* track progress of iter&enum

* reformat

* reformat with yapf

* add unitest

* add doc, and deprecate track_enum_progress

* update docs & comments
pull/117/head
ZwwWayne 2019-08-20 15:05:17 +08:00 committed by Kai Chen
parent 9de7927a69
commit d5865e0cbd
4 changed files with 81 additions and 2 deletions

View File

@ -54,6 +54,25 @@ mmcv.track_parallel_progress(func, tasks, 8) # 8 workers
![progress](_static/parallel_progress.gif)
If you want to iterate or enumerate a list of items and track the progress, `track_iter_progress`
is a good choice. It will display a progress bar to tell the progress and ETA.
```python
import mmcv
tasks = [item_1, item_2, ..., item_n]
for task in mmcv.track_iter_progress(tasks):
# do something like print
print(task)
for i, task in enumerate(mmcv.track_iter_progress(tasks)):
# do something like print
print(i)
print(task)
```
### Timer

View File

@ -4,7 +4,8 @@ from .misc import (is_str, iter_cast, list_cast, tuple_cast, is_seq_of,
check_prerequisites, requires_package, requires_executable)
from .path import (is_filepath, fopen, check_file_exist, mkdir_or_exist,
symlink, scandir, FileNotFoundError)
from .progressbar import ProgressBar, track_progress, track_parallel_progress
from .progressbar import (ProgressBar, track_progress, track_parallel_progress,
track_iter_progress)
from .timer import Timer, TimerError, check_time
__all__ = [
@ -13,5 +14,6 @@ __all__ = [
'check_prerequisites', 'requires_package', 'requires_executable',
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'FileNotFoundError', 'ProgressBar', 'track_progress',
'track_parallel_progress', 'Timer', 'TimerError', 'check_time'
'track_iter_progress', 'track_parallel_progress', 'Timer', 'TimerError',
'check_time'
]

View File

@ -172,3 +172,34 @@ def track_parallel_progress(func,
pool.close()
pool.join()
return results
def track_iter_progress(tasks, bar_width=50, **kwargs):
"""Track the progress of tasks iteration or enumeration with a progress bar.
Tasks are yielded with a simple for-loop.
Args:
tasks (list or tuple[Iterable, int]): A list of tasks or
(tasks, total num).
bar_width (int): Width of progress bar.
Yields:
list: The task results.
"""
if isinstance(tasks, tuple):
assert len(tasks) == 2
assert isinstance(tasks[0], collections_abc.Iterable)
assert isinstance(tasks[1], int)
task_num = tasks[1]
tasks = tasks[0]
elif isinstance(tasks, collections_abc.Iterable):
task_num = len(tasks)
else:
raise TypeError(
'"tasks" must be an iterable object or a (iterator, int) tuple')
prog_bar = ProgressBar(task_num, bar_width)
for task in tasks:
yield task
prog_bar.update()
sys.stdout.write('\n')

View File

@ -81,6 +81,33 @@ def test_track_progress_iterator(capsys):
assert ret == [1, 2, 3]
def test_track_iter_progress(capsys):
ret = []
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3):
ret.append(sleep_1s(num))
out, _ = capsys.readouterr()
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_enum_progress(capsys):
ret = []
count = []
for i, num in enumerate(mmcv.track_iter_progress([1, 2, 3], bar_width=3)):
ret.append(sleep_1s(num))
count.append(i)
out, _ = capsys.readouterr()
assert out == ('[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
assert count == [0, 1, 2]
def test_track_parallel_progress_list(capsys):
results = mmcv.track_parallel_progress(