mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Migrate utils from mmcv (#447)
This commit is contained in:
parent
5e1ef1dd6c
commit
b75962a660
@ -14,8 +14,13 @@ from .package_utils import (call_command, check_install_package,
|
|||||||
from .parrots_wrapper import TORCH_VERSION
|
from .parrots_wrapper import TORCH_VERSION
|
||||||
from .path import (check_file_exist, fopen, is_abs, is_filepath,
|
from .path import (check_file_exist, fopen, is_abs, is_filepath,
|
||||||
mkdir_or_exist, scandir, symlink)
|
mkdir_or_exist, scandir, symlink)
|
||||||
|
from .progressbar import (ProgressBar, track_iter_progress,
|
||||||
|
track_parallel_progress, track_progress)
|
||||||
from .setup_env import set_multi_processing
|
from .setup_env import set_multi_processing
|
||||||
from .sync_bn import revert_sync_batchnorm
|
from .sync_bn import revert_sync_batchnorm
|
||||||
|
from .timer import Timer, check_time
|
||||||
|
from .torch_ops import torch_meshgrid
|
||||||
|
from .trace import is_jit_tracing
|
||||||
from .version_utils import digit_version, get_git_hash
|
from .version_utils import digit_version, get_git_hash
|
||||||
|
|
||||||
# TODO: creates intractable circular import issues
|
# TODO: creates intractable circular import issues
|
||||||
@ -32,5 +37,8 @@ __all__ = [
|
|||||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||||
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
|
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
|
||||||
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
|
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
|
||||||
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env'
|
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
|
||||||
|
'Timer', 'check_time', 'ProgressBar', 'track_iter_progress',
|
||||||
|
'track_parallel_progress', 'track_progress', 'torch_meshgrid',
|
||||||
|
'is_jit_tracing'
|
||||||
]
|
]
|
||||||
|
@ -106,3 +106,14 @@ DataLoader, PoolDataLoader = _get_dataloader()
|
|||||||
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
||||||
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
|
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
|
||||||
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
||||||
|
|
||||||
|
|
||||||
|
class SyncBatchNorm(SyncBatchNorm_): # type: ignore
|
||||||
|
|
||||||
|
def _check_input_dim(self, input):
|
||||||
|
if TORCH_VERSION == 'parrots':
|
||||||
|
if input.dim() < 2:
|
||||||
|
raise ValueError(
|
||||||
|
f'expected at least 2D input (got {input.dim()}D input)')
|
||||||
|
else:
|
||||||
|
super()._check_input_dim(input)
|
||||||
|
208
mmengine/utils/progressbar.py
Normal file
208
mmengine/utils/progressbar.py
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import sys
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from shutil import get_terminal_size
|
||||||
|
|
||||||
|
from .timer import Timer
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressBar:
|
||||||
|
"""A progress bar which can print the progress."""
|
||||||
|
|
||||||
|
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
|
||||||
|
self.task_num = task_num
|
||||||
|
self.bar_width = bar_width
|
||||||
|
self.completed = 0
|
||||||
|
self.file = file
|
||||||
|
if start:
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def terminal_width(self):
|
||||||
|
width, _ = get_terminal_size()
|
||||||
|
return width
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if self.task_num > 0:
|
||||||
|
self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
|
||||||
|
'elapsed: 0s, ETA:')
|
||||||
|
else:
|
||||||
|
self.file.write('completed: 0, elapsed: 0s')
|
||||||
|
self.file.flush()
|
||||||
|
self.timer = Timer()
|
||||||
|
|
||||||
|
def update(self, num_tasks=1):
|
||||||
|
assert num_tasks > 0
|
||||||
|
self.completed += num_tasks
|
||||||
|
elapsed = self.timer.since_start()
|
||||||
|
if elapsed > 0:
|
||||||
|
fps = self.completed / elapsed
|
||||||
|
else:
|
||||||
|
fps = float('inf')
|
||||||
|
if self.task_num > 0:
|
||||||
|
percentage = self.completed / float(self.task_num)
|
||||||
|
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
|
||||||
|
msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
|
||||||
|
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
|
||||||
|
f'ETA: {eta:5}s'
|
||||||
|
|
||||||
|
bar_width = min(self.bar_width,
|
||||||
|
int(self.terminal_width - len(msg)) + 2,
|
||||||
|
int(self.terminal_width * 0.6))
|
||||||
|
bar_width = max(2, bar_width)
|
||||||
|
mark_width = int(bar_width * percentage)
|
||||||
|
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
|
||||||
|
self.file.write(msg.format(bar_chars))
|
||||||
|
else:
|
||||||
|
self.file.write(
|
||||||
|
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
|
||||||
|
f' {fps:.1f} tasks/s')
|
||||||
|
self.file.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
|
||||||
|
"""Track the progress of tasks execution with a progress bar.
|
||||||
|
|
||||||
|
Tasks are done with a simple for-loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (callable): The function to be applied to each task.
|
||||||
|
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||||
|
(tasks, total num).
|
||||||
|
bar_width (int): Width of progress bar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The task results.
|
||||||
|
"""
|
||||||
|
if isinstance(tasks, tuple):
|
||||||
|
assert len(tasks) == 2
|
||||||
|
assert isinstance(tasks[0], Iterable)
|
||||||
|
assert isinstance(tasks[1], int)
|
||||||
|
task_num = tasks[1]
|
||||||
|
tasks = tasks[0]
|
||||||
|
elif isinstance(tasks, Iterable):
|
||||||
|
task_num = len(tasks)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||||
|
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
||||||
|
results = []
|
||||||
|
for task in tasks:
|
||||||
|
results.append(func(task, **kwargs))
|
||||||
|
prog_bar.update()
|
||||||
|
prog_bar.file.write('\n')
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def init_pool(process_num, initializer=None, initargs=None):
|
||||||
|
if initializer is None:
|
||||||
|
return Pool(process_num)
|
||||||
|
elif initargs is None:
|
||||||
|
return Pool(process_num, initializer)
|
||||||
|
else:
|
||||||
|
if not isinstance(initargs, tuple):
|
||||||
|
raise TypeError('"initargs" must be a tuple')
|
||||||
|
return Pool(process_num, initializer, initargs)
|
||||||
|
|
||||||
|
|
||||||
|
def track_parallel_progress(func,
|
||||||
|
tasks,
|
||||||
|
nproc,
|
||||||
|
initializer=None,
|
||||||
|
initargs=None,
|
||||||
|
bar_width=50,
|
||||||
|
chunksize=1,
|
||||||
|
skip_first=False,
|
||||||
|
keep_order=True,
|
||||||
|
file=sys.stdout):
|
||||||
|
"""Track the progress of parallel task execution with a progress bar.
|
||||||
|
|
||||||
|
The built-in :mod:`multiprocessing` module is used for process pools and
|
||||||
|
tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (callable): The function to be applied to each task.
|
||||||
|
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||||
|
(tasks, total num).
|
||||||
|
nproc (int): Process (worker) number.
|
||||||
|
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
|
||||||
|
for details.
|
||||||
|
initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
|
||||||
|
details.
|
||||||
|
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
|
||||||
|
bar_width (int): Width of progress bar.
|
||||||
|
skip_first (bool): Whether to skip the first sample for each worker
|
||||||
|
when estimating fps, since the initialization step may takes
|
||||||
|
longer.
|
||||||
|
keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
|
||||||
|
:func:`Pool.imap_unordered` is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The task results.
|
||||||
|
"""
|
||||||
|
if isinstance(tasks, tuple):
|
||||||
|
assert len(tasks) == 2
|
||||||
|
assert isinstance(tasks[0], Iterable)
|
||||||
|
assert isinstance(tasks[1], int)
|
||||||
|
task_num = tasks[1]
|
||||||
|
tasks = tasks[0]
|
||||||
|
elif isinstance(tasks, Iterable):
|
||||||
|
task_num = len(tasks)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||||
|
pool = init_pool(nproc, initializer, initargs)
|
||||||
|
start = not skip_first
|
||||||
|
task_num -= nproc * chunksize * int(skip_first)
|
||||||
|
prog_bar = ProgressBar(task_num, bar_width, start, file=file)
|
||||||
|
results = []
|
||||||
|
if keep_order:
|
||||||
|
gen = pool.imap(func, tasks, chunksize)
|
||||||
|
else:
|
||||||
|
gen = pool.imap_unordered(func, tasks, chunksize)
|
||||||
|
for result in gen:
|
||||||
|
results.append(result)
|
||||||
|
if skip_first:
|
||||||
|
if len(results) < nproc * chunksize:
|
||||||
|
continue
|
||||||
|
elif len(results) == nproc * chunksize:
|
||||||
|
prog_bar.start()
|
||||||
|
continue
|
||||||
|
prog_bar.update()
|
||||||
|
prog_bar.file.write('\n')
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
|
||||||
|
"""Track the progress of tasks iteration or enumeration with a progress
|
||||||
|
bar.
|
||||||
|
|
||||||
|
Tasks are yielded with a simple for-loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||||
|
(tasks, total num).
|
||||||
|
bar_width (int): Width of progress bar.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
list: The task results.
|
||||||
|
"""
|
||||||
|
if isinstance(tasks, tuple):
|
||||||
|
assert len(tasks) == 2
|
||||||
|
assert isinstance(tasks[0], Iterable)
|
||||||
|
assert isinstance(tasks[1], int)
|
||||||
|
task_num = tasks[1]
|
||||||
|
tasks = tasks[0]
|
||||||
|
elif isinstance(tasks, Iterable):
|
||||||
|
task_num = len(tasks)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||||
|
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
||||||
|
for task in tasks:
|
||||||
|
yield task
|
||||||
|
prog_bar.update()
|
||||||
|
prog_bar.file.write('\n')
|
118
mmengine/utils/timer.py
Normal file
118
mmengine/utils/timer.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
|
||||||
|
class TimerError(Exception):
|
||||||
|
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
"""A flexible Timer class.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import time
|
||||||
|
>>> import mmcv
|
||||||
|
>>> with mmcv.Timer():
|
||||||
|
>>> # simulate a code block that will run for 1s
|
||||||
|
>>> time.sleep(1)
|
||||||
|
1.000
|
||||||
|
>>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
|
||||||
|
>>> # simulate a code block that will run for 1s
|
||||||
|
>>> time.sleep(1)
|
||||||
|
it takes 1.0 seconds
|
||||||
|
>>> timer = mmcv.Timer()
|
||||||
|
>>> time.sleep(0.5)
|
||||||
|
>>> print(timer.since_start())
|
||||||
|
0.500
|
||||||
|
>>> time.sleep(0.5)
|
||||||
|
>>> print(timer.since_last_check())
|
||||||
|
0.500
|
||||||
|
>>> print(timer.since_start())
|
||||||
|
1.000
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, start=True, print_tmpl=None):
|
||||||
|
self._is_running = False
|
||||||
|
self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
|
||||||
|
if start:
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self):
|
||||||
|
"""bool: indicate whether the timer is running"""
|
||||||
|
return self._is_running
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
print(self.print_tmpl.format(self.since_last_check()))
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start the timer."""
|
||||||
|
if not self._is_running:
|
||||||
|
self._t_start = time()
|
||||||
|
self._is_running = True
|
||||||
|
self._t_last = time()
|
||||||
|
|
||||||
|
def since_start(self):
|
||||||
|
"""Total time since the timer is started.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Time in seconds.
|
||||||
|
"""
|
||||||
|
if not self._is_running:
|
||||||
|
raise TimerError('timer is not running')
|
||||||
|
self._t_last = time()
|
||||||
|
return self._t_last - self._t_start
|
||||||
|
|
||||||
|
def since_last_check(self):
|
||||||
|
"""Time since the last checking.
|
||||||
|
|
||||||
|
Either :func:`since_start` or :func:`since_last_check` is a checking
|
||||||
|
operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Time in seconds.
|
||||||
|
"""
|
||||||
|
if not self._is_running:
|
||||||
|
raise TimerError('timer is not running')
|
||||||
|
dur = time() - self._t_last
|
||||||
|
self._t_last = time()
|
||||||
|
return dur
|
||||||
|
|
||||||
|
|
||||||
|
_g_timers = {} # global timers
|
||||||
|
|
||||||
|
|
||||||
|
def check_time(timer_id):
|
||||||
|
"""Add check points in a single line.
|
||||||
|
|
||||||
|
This method is suitable for running a task on a list of items. A timer will
|
||||||
|
be registered when the method is called for the first time.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import time
|
||||||
|
>>> import mmcv
|
||||||
|
>>> for i in range(1, 6):
|
||||||
|
>>> # simulate a code block
|
||||||
|
>>> time.sleep(i)
|
||||||
|
>>> mmcv.check_time('task1')
|
||||||
|
2.000
|
||||||
|
3.000
|
||||||
|
4.000
|
||||||
|
5.000
|
||||||
|
|
||||||
|
Args:
|
||||||
|
str: Timer identifier.
|
||||||
|
"""
|
||||||
|
if timer_id not in _g_timers:
|
||||||
|
_g_timers[timer_id] = Timer()
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return _g_timers[timer_id].since_last_check()
|
29
mmengine/utils/torch_ops.py
Normal file
29
mmengine/utils/torch_ops.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .parrots_wrapper import TORCH_VERSION
|
||||||
|
from .version_utils import digit_version
|
||||||
|
|
||||||
|
_torch_version_meshgrid_indexing = (
|
||||||
|
'parrots' not in TORCH_VERSION
|
||||||
|
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
|
||||||
|
|
||||||
|
|
||||||
|
def torch_meshgrid(*tensors):
|
||||||
|
"""A wrapper of torch.meshgrid to compat different PyTorch versions.
|
||||||
|
|
||||||
|
Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
|
||||||
|
So we implement a wrapper here to avoid warning when using high-version
|
||||||
|
PyTorch and avoid compatibility issues when using previous versions of
|
||||||
|
PyTorch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sequence[Tensor]: Sequence of meshgrid tensors.
|
||||||
|
"""
|
||||||
|
if _torch_version_meshgrid_indexing:
|
||||||
|
return torch.meshgrid(*tensors, indexing='ij')
|
||||||
|
else:
|
||||||
|
return torch.meshgrid(*tensors) # Uses indexing='ij' by default
|
24
mmengine/utils/trace.py
Normal file
24
mmengine/utils/trace.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .version_utils import digit_version
|
||||||
|
|
||||||
|
|
||||||
|
def is_jit_tracing() -> bool:
|
||||||
|
if (torch.__version__ != 'parrots'
|
||||||
|
and digit_version(torch.__version__) >= digit_version('1.6.0')):
|
||||||
|
on_trace = torch.jit.is_tracing()
|
||||||
|
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
||||||
|
# Refers to https://github.com/pytorch/pytorch/issues/42448
|
||||||
|
if isinstance(on_trace, bool):
|
||||||
|
return on_trace
|
||||||
|
else:
|
||||||
|
return torch._C._is_tracing()
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
'torch.jit.is_tracing is only supported after v1.6.0. '
|
||||||
|
'Therefore is_tracing returns False automatically. Please '
|
||||||
|
'set on_trace manually if you are using trace.', UserWarning)
|
||||||
|
return False
|
163
tests/test_utils/test_progressbar.py
Normal file
163
tests/test_utils/test_progressbar.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from io import StringIO
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
|
||||||
|
|
||||||
|
def reset_string_io(io):
|
||||||
|
io.truncate(0)
|
||||||
|
io.seek(0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressBar:
|
||||||
|
|
||||||
|
def test_start(self):
|
||||||
|
out = StringIO()
|
||||||
|
bar_width = 20
|
||||||
|
# without total task num
|
||||||
|
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
|
||||||
|
assert out.getvalue() == 'completed: 0, elapsed: 0s'
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
|
||||||
|
assert out.getvalue() == ''
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.start()
|
||||||
|
assert out.getvalue() == 'completed: 0, elapsed: 0s'
|
||||||
|
# with total task num
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
|
||||||
|
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar = mmcv.ProgressBar(
|
||||||
|
10, bar_width=bar_width, start=False, file=out)
|
||||||
|
assert out.getvalue() == ''
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.start()
|
||||||
|
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
|
||||||
|
|
||||||
|
def test_update(self):
|
||||||
|
out = StringIO()
|
||||||
|
bar_width = 20
|
||||||
|
# without total task num
|
||||||
|
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
|
||||||
|
time.sleep(1)
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.update()
|
||||||
|
assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
|
||||||
|
reset_string_io(out)
|
||||||
|
# with total task num
|
||||||
|
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
|
||||||
|
time.sleep(1)
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.update()
|
||||||
|
assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \
|
||||||
|
'task/s, elapsed: 1s, ETA: 9s'
|
||||||
|
|
||||||
|
def test_adaptive_length(self):
|
||||||
|
with patch.dict('os.environ', {'COLUMNS': '80'}):
|
||||||
|
out = StringIO()
|
||||||
|
bar_width = 20
|
||||||
|
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
|
||||||
|
time.sleep(1)
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.update()
|
||||||
|
assert len(out.getvalue()) == 66
|
||||||
|
|
||||||
|
os.environ['COLUMNS'] = '30'
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.update()
|
||||||
|
assert len(out.getvalue()) == 48
|
||||||
|
|
||||||
|
os.environ['COLUMNS'] = '60'
|
||||||
|
reset_string_io(out)
|
||||||
|
prog_bar.update()
|
||||||
|
assert len(out.getvalue()) == 60
|
||||||
|
|
||||||
|
|
||||||
|
def sleep_1s(num):
|
||||||
|
time.sleep(1)
|
||||||
|
return num
|
||||||
|
|
||||||
|
|
||||||
|
def test_track_progress_list():
|
||||||
|
out = StringIO()
|
||||||
|
ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
|
||||||
|
assert out.getvalue() == (
|
||||||
|
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||||
|
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||||
|
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||||
|
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||||
|
assert ret == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_track_progress_iterator():
|
||||||
|
out = StringIO()
|
||||||
|
ret = mmcv.track_progress(
|
||||||
|
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
|
||||||
|
assert out.getvalue() == (
|
||||||
|
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||||
|
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||||
|
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||||
|
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||||
|
assert ret == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_track_iter_progress():
|
||||||
|
out = StringIO()
|
||||||
|
ret = []
|
||||||
|
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
|
||||||
|
ret.append(sleep_1s(num))
|
||||||
|
assert out.getvalue() == (
|
||||||
|
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||||
|
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||||
|
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||||
|
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||||
|
assert ret == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_track_enum_progress():
|
||||||
|
out = StringIO()
|
||||||
|
ret = []
|
||||||
|
count = []
|
||||||
|
for i, num in enumerate(
|
||||||
|
mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
|
||||||
|
ret.append(sleep_1s(num))
|
||||||
|
count.append(i)
|
||||||
|
assert out.getvalue() == (
|
||||||
|
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||||
|
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||||
|
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||||
|
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||||
|
assert ret == [1, 2, 3]
|
||||||
|
assert count == [0, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_track_parallel_progress_list():
|
||||||
|
out = StringIO()
|
||||||
|
results = mmcv.track_parallel_progress(
|
||||||
|
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
|
||||||
|
# The following cannot pass CI on Github Action
|
||||||
|
# assert out.getvalue() == (
|
||||||
|
# '[ ] 0/4, elapsed: 0s, ETA:'
|
||||||
|
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
|
||||||
|
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
|
||||||
|
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
|
||||||
|
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
|
||||||
|
assert results == [1, 2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_track_parallel_progress_iterator():
|
||||||
|
out = StringIO()
|
||||||
|
results = mmcv.track_parallel_progress(
|
||||||
|
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
|
||||||
|
# The following cannot pass CI on Github Action
|
||||||
|
# assert out.getvalue() == (
|
||||||
|
# '[ ] 0/4, elapsed: 0s, ETA:'
|
||||||
|
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
|
||||||
|
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
|
||||||
|
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
|
||||||
|
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
|
||||||
|
assert results == [1, 2, 3, 4]
|
39
tests/test_utils/test_timer.py
Normal file
39
tests/test_utils/test_timer.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_timer_init():
|
||||||
|
timer = mmcv.Timer(start=False)
|
||||||
|
assert not timer.is_running
|
||||||
|
timer.start()
|
||||||
|
assert timer.is_running
|
||||||
|
timer = mmcv.Timer()
|
||||||
|
assert timer.is_running
|
||||||
|
|
||||||
|
|
||||||
|
def test_timer_run():
|
||||||
|
timer = mmcv.Timer()
|
||||||
|
time.sleep(1)
|
||||||
|
assert abs(timer.since_start() - 1) < 1e-2
|
||||||
|
time.sleep(1)
|
||||||
|
assert abs(timer.since_last_check() - 1) < 1e-2
|
||||||
|
assert abs(timer.since_start() - 2) < 1e-2
|
||||||
|
timer = mmcv.Timer(False)
|
||||||
|
with pytest.raises(mmcv.TimerError):
|
||||||
|
timer.since_start()
|
||||||
|
with pytest.raises(mmcv.TimerError):
|
||||||
|
timer.since_last_check()
|
||||||
|
|
||||||
|
|
||||||
|
def test_timer_context(capsys):
|
||||||
|
with mmcv.Timer():
|
||||||
|
time.sleep(1)
|
||||||
|
out, _ = capsys.readouterr()
|
||||||
|
assert abs(float(out) - 1) < 1e-2
|
||||||
|
with mmcv.Timer(print_tmpl='time: {:.1f}s'):
|
||||||
|
time.sleep(1)
|
||||||
|
out, _ = capsys.readouterr()
|
||||||
|
assert out == 'time: 1.0s\n'
|
15
tests/test_utils/test_torch_ops.py
Normal file
15
tests/test_utils/test_torch_ops.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.utils import torch_meshgrid
|
||||||
|
|
||||||
|
|
||||||
|
def test_torch_meshgrid():
|
||||||
|
# torch_meshgrid should not throw warning
|
||||||
|
with pytest.warns(None) as record:
|
||||||
|
x = torch.tensor([1, 2, 3])
|
||||||
|
y = torch.tensor([4, 5, 6])
|
||||||
|
grid_x, grid_y = torch_meshgrid(x, y)
|
||||||
|
|
||||||
|
assert len(record) == 0
|
25
tests/test_utils/test_trace.py
Normal file
25
tests/test_utils/test_trace.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.utils import digit_version, is_jit_tracing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
digit_version(torch.__version__) < digit_version('1.6.0'),
|
||||||
|
reason='torch.jit.is_tracing is not available before 1.6.0')
|
||||||
|
def test_is_jit_tracing():
|
||||||
|
|
||||||
|
def foo(x):
|
||||||
|
if is_jit_tracing():
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x.tolist()
|
||||||
|
|
||||||
|
x = torch.rand(3)
|
||||||
|
# test without trace
|
||||||
|
assert isinstance(foo(x), list)
|
||||||
|
|
||||||
|
# test with trace
|
||||||
|
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
|
||||||
|
assert isinstance(traced_foo(x), torch.Tensor)
|
Loading…
x
Reference in New Issue
Block a user