41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Optional
|
|
|
|
import mmengine.dist as dist
|
|
import rich.progress as progress
|
|
from rich.live import Live
|
|
|
|
disable_progress_bar = False
|
|
global_progress = progress.Progress(
|
|
'{task.description}',
|
|
progress.BarColumn(),
|
|
progress.TaskProgressColumn(show_speed=True),
|
|
progress.TimeRemainingColumn(),
|
|
)
|
|
global_live = Live(global_progress, refresh_per_second=10)
|
|
|
|
|
|
def track(sequence, description: str = '', total: Optional[float] = None):
|
|
if disable_progress_bar:
|
|
yield from sequence
|
|
else:
|
|
global_live.start()
|
|
task_id = global_progress.add_task(description, total=total)
|
|
task = global_progress._tasks[task_id]
|
|
try:
|
|
yield from global_progress.track(sequence, task_id=task_id)
|
|
finally:
|
|
if task.total is None:
|
|
global_progress.update(task_id, total=task.completed)
|
|
if all(task.finished for task in global_progress.tasks):
|
|
global_live.stop()
|
|
for task_id in global_progress.task_ids:
|
|
global_progress.remove_task(task_id)
|
|
|
|
|
|
def track_on_main_process(sequence, description='', total=None):
|
|
if not dist.is_main_process() or disable_progress_bar:
|
|
yield from sequence
|
|
else:
|
|
yield from track(sequence, total=total, description=description)
|