mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Migrate utils from mmcv (#447)
This commit is contained in:
parent
5e1ef1dd6c
commit
b75962a660
@ -14,8 +14,13 @@ from .package_utils import (call_command, check_install_package,
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .path import (check_file_exist, fopen, is_abs, is_filepath,
|
||||
mkdir_or_exist, scandir, symlink)
|
||||
from .progressbar import (ProgressBar, track_iter_progress,
|
||||
track_parallel_progress, track_progress)
|
||||
from .setup_env import set_multi_processing
|
||||
from .sync_bn import revert_sync_batchnorm
|
||||
from .timer import Timer, check_time
|
||||
from .torch_ops import torch_meshgrid
|
||||
from .trace import is_jit_tracing
|
||||
from .version_utils import digit_version, get_git_hash
|
||||
|
||||
# TODO: creates intractable circular import issues
|
||||
@ -32,5 +37,8 @@ __all__ = [
|
||||
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
|
||||
'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm',
|
||||
'is_abs', 'is_installed', 'call_command', 'get_installed_path',
|
||||
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env'
|
||||
'check_install_package', 'is_abs', 'revert_sync_batchnorm', 'collect_env',
|
||||
'Timer', 'check_time', 'ProgressBar', 'track_iter_progress',
|
||||
'track_parallel_progress', 'track_progress', 'torch_meshgrid',
|
||||
'is_jit_tracing'
|
||||
]
|
||||
|
@ -106,3 +106,14 @@ DataLoader, PoolDataLoader = _get_dataloader()
|
||||
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
||||
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
|
||||
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
||||
|
||||
|
||||
class SyncBatchNorm(SyncBatchNorm_): # type: ignore
|
||||
|
||||
def _check_input_dim(self, input):
|
||||
if TORCH_VERSION == 'parrots':
|
||||
if input.dim() < 2:
|
||||
raise ValueError(
|
||||
f'expected at least 2D input (got {input.dim()}D input)')
|
||||
else:
|
||||
super()._check_input_dim(input)
|
||||
|
208
mmengine/utils/progressbar.py
Normal file
208
mmengine/utils/progressbar.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from multiprocessing import Pool
|
||||
from shutil import get_terminal_size
|
||||
|
||||
from .timer import Timer
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
"""A progress bar which can print the progress."""
|
||||
|
||||
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
|
||||
self.task_num = task_num
|
||||
self.bar_width = bar_width
|
||||
self.completed = 0
|
||||
self.file = file
|
||||
if start:
|
||||
self.start()
|
||||
|
||||
@property
|
||||
def terminal_width(self):
|
||||
width, _ = get_terminal_size()
|
||||
return width
|
||||
|
||||
def start(self):
|
||||
if self.task_num > 0:
|
||||
self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
|
||||
'elapsed: 0s, ETA:')
|
||||
else:
|
||||
self.file.write('completed: 0, elapsed: 0s')
|
||||
self.file.flush()
|
||||
self.timer = Timer()
|
||||
|
||||
def update(self, num_tasks=1):
|
||||
assert num_tasks > 0
|
||||
self.completed += num_tasks
|
||||
elapsed = self.timer.since_start()
|
||||
if elapsed > 0:
|
||||
fps = self.completed / elapsed
|
||||
else:
|
||||
fps = float('inf')
|
||||
if self.task_num > 0:
|
||||
percentage = self.completed / float(self.task_num)
|
||||
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
|
||||
msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
|
||||
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
|
||||
f'ETA: {eta:5}s'
|
||||
|
||||
bar_width = min(self.bar_width,
|
||||
int(self.terminal_width - len(msg)) + 2,
|
||||
int(self.terminal_width * 0.6))
|
||||
bar_width = max(2, bar_width)
|
||||
mark_width = int(bar_width * percentage)
|
||||
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
|
||||
self.file.write(msg.format(bar_chars))
|
||||
else:
|
||||
self.file.write(
|
||||
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
|
||||
f' {fps:.1f} tasks/s')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
|
||||
"""Track the progress of tasks execution with a progress bar.
|
||||
|
||||
Tasks are done with a simple for-loop.
|
||||
|
||||
Args:
|
||||
func (callable): The function to be applied to each task.
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
bar_width (int): Width of progress bar.
|
||||
|
||||
Returns:
|
||||
list: The task results.
|
||||
"""
|
||||
if isinstance(tasks, tuple):
|
||||
assert len(tasks) == 2
|
||||
assert isinstance(tasks[0], Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, Iterable):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
||||
results = []
|
||||
for task in tasks:
|
||||
results.append(func(task, **kwargs))
|
||||
prog_bar.update()
|
||||
prog_bar.file.write('\n')
|
||||
return results
|
||||
|
||||
|
||||
def init_pool(process_num, initializer=None, initargs=None):
|
||||
if initializer is None:
|
||||
return Pool(process_num)
|
||||
elif initargs is None:
|
||||
return Pool(process_num, initializer)
|
||||
else:
|
||||
if not isinstance(initargs, tuple):
|
||||
raise TypeError('"initargs" must be a tuple')
|
||||
return Pool(process_num, initializer, initargs)
|
||||
|
||||
|
||||
def track_parallel_progress(func,
|
||||
tasks,
|
||||
nproc,
|
||||
initializer=None,
|
||||
initargs=None,
|
||||
bar_width=50,
|
||||
chunksize=1,
|
||||
skip_first=False,
|
||||
keep_order=True,
|
||||
file=sys.stdout):
|
||||
"""Track the progress of parallel task execution with a progress bar.
|
||||
|
||||
The built-in :mod:`multiprocessing` module is used for process pools and
|
||||
tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
|
||||
|
||||
Args:
|
||||
func (callable): The function to be applied to each task.
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
nproc (int): Process (worker) number.
|
||||
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
|
||||
for details.
|
||||
initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
|
||||
details.
|
||||
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
|
||||
bar_width (int): Width of progress bar.
|
||||
skip_first (bool): Whether to skip the first sample for each worker
|
||||
when estimating fps, since the initialization step may takes
|
||||
longer.
|
||||
keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
|
||||
:func:`Pool.imap_unordered` is used.
|
||||
|
||||
Returns:
|
||||
list: The task results.
|
||||
"""
|
||||
if isinstance(tasks, tuple):
|
||||
assert len(tasks) == 2
|
||||
assert isinstance(tasks[0], Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, Iterable):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||
pool = init_pool(nproc, initializer, initargs)
|
||||
start = not skip_first
|
||||
task_num -= nproc * chunksize * int(skip_first)
|
||||
prog_bar = ProgressBar(task_num, bar_width, start, file=file)
|
||||
results = []
|
||||
if keep_order:
|
||||
gen = pool.imap(func, tasks, chunksize)
|
||||
else:
|
||||
gen = pool.imap_unordered(func, tasks, chunksize)
|
||||
for result in gen:
|
||||
results.append(result)
|
||||
if skip_first:
|
||||
if len(results) < nproc * chunksize:
|
||||
continue
|
||||
elif len(results) == nproc * chunksize:
|
||||
prog_bar.start()
|
||||
continue
|
||||
prog_bar.update()
|
||||
prog_bar.file.write('\n')
|
||||
pool.close()
|
||||
pool.join()
|
||||
return results
|
||||
|
||||
|
||||
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
|
||||
"""Track the progress of tasks iteration or enumeration with a progress
|
||||
bar.
|
||||
|
||||
Tasks are yielded with a simple for-loop.
|
||||
|
||||
Args:
|
||||
tasks (list or tuple[Iterable, int]): A list of tasks or
|
||||
(tasks, total num).
|
||||
bar_width (int): Width of progress bar.
|
||||
|
||||
Yields:
|
||||
list: The task results.
|
||||
"""
|
||||
if isinstance(tasks, tuple):
|
||||
assert len(tasks) == 2
|
||||
assert isinstance(tasks[0], Iterable)
|
||||
assert isinstance(tasks[1], int)
|
||||
task_num = tasks[1]
|
||||
tasks = tasks[0]
|
||||
elif isinstance(tasks, Iterable):
|
||||
task_num = len(tasks)
|
||||
else:
|
||||
raise TypeError(
|
||||
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
||||
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
||||
for task in tasks:
|
||||
yield task
|
||||
prog_bar.update()
|
||||
prog_bar.file.write('\n')
|
118
mmengine/utils/timer.py
Normal file
118
mmengine/utils/timer.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from time import time
|
||||
|
||||
|
||||
class TimerError(Exception):
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Timer:
|
||||
"""A flexible Timer class.
|
||||
|
||||
Examples:
|
||||
>>> import time
|
||||
>>> import mmcv
|
||||
>>> with mmcv.Timer():
|
||||
>>> # simulate a code block that will run for 1s
|
||||
>>> time.sleep(1)
|
||||
1.000
|
||||
>>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
|
||||
>>> # simulate a code block that will run for 1s
|
||||
>>> time.sleep(1)
|
||||
it takes 1.0 seconds
|
||||
>>> timer = mmcv.Timer()
|
||||
>>> time.sleep(0.5)
|
||||
>>> print(timer.since_start())
|
||||
0.500
|
||||
>>> time.sleep(0.5)
|
||||
>>> print(timer.since_last_check())
|
||||
0.500
|
||||
>>> print(timer.since_start())
|
||||
1.000
|
||||
"""
|
||||
|
||||
def __init__(self, start=True, print_tmpl=None):
|
||||
self._is_running = False
|
||||
self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
|
||||
if start:
|
||||
self.start()
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
"""bool: indicate whether the timer is running"""
|
||||
return self._is_running
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
print(self.print_tmpl.format(self.since_last_check()))
|
||||
self._is_running = False
|
||||
|
||||
def start(self):
|
||||
"""Start the timer."""
|
||||
if not self._is_running:
|
||||
self._t_start = time()
|
||||
self._is_running = True
|
||||
self._t_last = time()
|
||||
|
||||
def since_start(self):
|
||||
"""Total time since the timer is started.
|
||||
|
||||
Returns:
|
||||
float: Time in seconds.
|
||||
"""
|
||||
if not self._is_running:
|
||||
raise TimerError('timer is not running')
|
||||
self._t_last = time()
|
||||
return self._t_last - self._t_start
|
||||
|
||||
def since_last_check(self):
|
||||
"""Time since the last checking.
|
||||
|
||||
Either :func:`since_start` or :func:`since_last_check` is a checking
|
||||
operation.
|
||||
|
||||
Returns:
|
||||
float: Time in seconds.
|
||||
"""
|
||||
if not self._is_running:
|
||||
raise TimerError('timer is not running')
|
||||
dur = time() - self._t_last
|
||||
self._t_last = time()
|
||||
return dur
|
||||
|
||||
|
||||
_g_timers = {} # global timers
|
||||
|
||||
|
||||
def check_time(timer_id):
|
||||
"""Add check points in a single line.
|
||||
|
||||
This method is suitable for running a task on a list of items. A timer will
|
||||
be registered when the method is called for the first time.
|
||||
|
||||
Examples:
|
||||
>>> import time
|
||||
>>> import mmcv
|
||||
>>> for i in range(1, 6):
|
||||
>>> # simulate a code block
|
||||
>>> time.sleep(i)
|
||||
>>> mmcv.check_time('task1')
|
||||
2.000
|
||||
3.000
|
||||
4.000
|
||||
5.000
|
||||
|
||||
Args:
|
||||
str: Timer identifier.
|
||||
"""
|
||||
if timer_id not in _g_timers:
|
||||
_g_timers[timer_id] = Timer()
|
||||
return 0
|
||||
else:
|
||||
return _g_timers[timer_id].since_last_check()
|
29
mmengine/utils/torch_ops.py
Normal file
29
mmengine/utils/torch_ops.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from .parrots_wrapper import TORCH_VERSION
|
||||
from .version_utils import digit_version
|
||||
|
||||
_torch_version_meshgrid_indexing = (
|
||||
'parrots' not in TORCH_VERSION
|
||||
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
|
||||
|
||||
|
||||
def torch_meshgrid(*tensors):
|
||||
"""A wrapper of torch.meshgrid to compat different PyTorch versions.
|
||||
|
||||
Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
|
||||
So we implement a wrapper here to avoid warning when using high-version
|
||||
PyTorch and avoid compatibility issues when using previous versions of
|
||||
PyTorch.
|
||||
|
||||
Args:
|
||||
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.
|
||||
|
||||
Returns:
|
||||
Sequence[Tensor]: Sequence of meshgrid tensors.
|
||||
"""
|
||||
if _torch_version_meshgrid_indexing:
|
||||
return torch.meshgrid(*tensors, indexing='ij')
|
||||
else:
|
||||
return torch.meshgrid(*tensors) # Uses indexing='ij' by default
|
24
mmengine/utils/trace.py
Normal file
24
mmengine/utils/trace.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from .version_utils import digit_version
|
||||
|
||||
|
||||
def is_jit_tracing() -> bool:
|
||||
if (torch.__version__ != 'parrots'
|
||||
and digit_version(torch.__version__) >= digit_version('1.6.0')):
|
||||
on_trace = torch.jit.is_tracing()
|
||||
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
||||
# Refers to https://github.com/pytorch/pytorch/issues/42448
|
||||
if isinstance(on_trace, bool):
|
||||
return on_trace
|
||||
else:
|
||||
return torch._C._is_tracing()
|
||||
else:
|
||||
warnings.warn(
|
||||
'torch.jit.is_tracing is only supported after v1.6.0. '
|
||||
'Therefore is_tracing returns False automatically. Please '
|
||||
'set on_trace manually if you are using trace.', UserWarning)
|
||||
return False
|
163
tests/test_utils/test_progressbar.py
Normal file
163
tests/test_utils/test_progressbar.py
Normal file
@ -0,0 +1,163 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import time
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import mmcv
|
||||
|
||||
|
||||
def reset_string_io(io):
|
||||
io.truncate(0)
|
||||
io.seek(0)
|
||||
|
||||
|
||||
class TestProgressBar:
|
||||
|
||||
def test_start(self):
|
||||
out = StringIO()
|
||||
bar_width = 20
|
||||
# without total task num
|
||||
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
|
||||
assert out.getvalue() == 'completed: 0, elapsed: 0s'
|
||||
reset_string_io(out)
|
||||
prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
|
||||
assert out.getvalue() == ''
|
||||
reset_string_io(out)
|
||||
prog_bar.start()
|
||||
assert out.getvalue() == 'completed: 0, elapsed: 0s'
|
||||
# with total task num
|
||||
reset_string_io(out)
|
||||
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
|
||||
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
|
||||
reset_string_io(out)
|
||||
prog_bar = mmcv.ProgressBar(
|
||||
10, bar_width=bar_width, start=False, file=out)
|
||||
assert out.getvalue() == ''
|
||||
reset_string_io(out)
|
||||
prog_bar.start()
|
||||
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
|
||||
|
||||
def test_update(self):
|
||||
out = StringIO()
|
||||
bar_width = 20
|
||||
# without total task num
|
||||
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
|
||||
time.sleep(1)
|
||||
reset_string_io(out)
|
||||
prog_bar.update()
|
||||
assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
|
||||
reset_string_io(out)
|
||||
# with total task num
|
||||
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
|
||||
time.sleep(1)
|
||||
reset_string_io(out)
|
||||
prog_bar.update()
|
||||
assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \
|
||||
'task/s, elapsed: 1s, ETA: 9s'
|
||||
|
||||
def test_adaptive_length(self):
|
||||
with patch.dict('os.environ', {'COLUMNS': '80'}):
|
||||
out = StringIO()
|
||||
bar_width = 20
|
||||
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
|
||||
time.sleep(1)
|
||||
reset_string_io(out)
|
||||
prog_bar.update()
|
||||
assert len(out.getvalue()) == 66
|
||||
|
||||
os.environ['COLUMNS'] = '30'
|
||||
reset_string_io(out)
|
||||
prog_bar.update()
|
||||
assert len(out.getvalue()) == 48
|
||||
|
||||
os.environ['COLUMNS'] = '60'
|
||||
reset_string_io(out)
|
||||
prog_bar.update()
|
||||
assert len(out.getvalue()) == 60
|
||||
|
||||
|
||||
def sleep_1s(num):
|
||||
time.sleep(1)
|
||||
return num
|
||||
|
||||
|
||||
def test_track_progress_list():
|
||||
out = StringIO()
|
||||
ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
def test_track_progress_iterator():
|
||||
out = StringIO()
|
||||
ret = mmcv.track_progress(
|
||||
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
def test_track_iter_progress():
|
||||
out = StringIO()
|
||||
ret = []
|
||||
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
|
||||
ret.append(sleep_1s(num))
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
|
||||
|
||||
def test_track_enum_progress():
|
||||
out = StringIO()
|
||||
ret = []
|
||||
count = []
|
||||
for i, num in enumerate(
|
||||
mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
|
||||
ret.append(sleep_1s(num))
|
||||
count.append(i)
|
||||
assert out.getvalue() == (
|
||||
'[ ] 0/3, elapsed: 0s, ETA:'
|
||||
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
|
||||
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
|
||||
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
|
||||
assert ret == [1, 2, 3]
|
||||
assert count == [0, 1, 2]
|
||||
|
||||
|
||||
def test_track_parallel_progress_list():
|
||||
out = StringIO()
|
||||
results = mmcv.track_parallel_progress(
|
||||
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
|
||||
# The following cannot pass CI on Github Action
|
||||
# assert out.getvalue() == (
|
||||
# '[ ] 0/4, elapsed: 0s, ETA:'
|
||||
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
|
||||
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
|
||||
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
|
||||
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
|
||||
assert results == [1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_track_parallel_progress_iterator():
|
||||
out = StringIO()
|
||||
results = mmcv.track_parallel_progress(
|
||||
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
|
||||
# The following cannot pass CI on Github Action
|
||||
# assert out.getvalue() == (
|
||||
# '[ ] 0/4, elapsed: 0s, ETA:'
|
||||
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
|
||||
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
|
||||
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
|
||||
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
|
||||
assert results == [1, 2, 3, 4]
|
39
tests/test_utils/test_timer.py
Normal file
39
tests/test_utils/test_timer.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import time
|
||||
|
||||
import mmcv
|
||||
import pytest
|
||||
|
||||
|
||||
def test_timer_init():
|
||||
timer = mmcv.Timer(start=False)
|
||||
assert not timer.is_running
|
||||
timer.start()
|
||||
assert timer.is_running
|
||||
timer = mmcv.Timer()
|
||||
assert timer.is_running
|
||||
|
||||
|
||||
def test_timer_run():
|
||||
timer = mmcv.Timer()
|
||||
time.sleep(1)
|
||||
assert abs(timer.since_start() - 1) < 1e-2
|
||||
time.sleep(1)
|
||||
assert abs(timer.since_last_check() - 1) < 1e-2
|
||||
assert abs(timer.since_start() - 2) < 1e-2
|
||||
timer = mmcv.Timer(False)
|
||||
with pytest.raises(mmcv.TimerError):
|
||||
timer.since_start()
|
||||
with pytest.raises(mmcv.TimerError):
|
||||
timer.since_last_check()
|
||||
|
||||
|
||||
def test_timer_context(capsys):
|
||||
with mmcv.Timer():
|
||||
time.sleep(1)
|
||||
out, _ = capsys.readouterr()
|
||||
assert abs(float(out) - 1) < 1e-2
|
||||
with mmcv.Timer(print_tmpl='time: {:.1f}s'):
|
||||
time.sleep(1)
|
||||
out, _ = capsys.readouterr()
|
||||
assert out == 'time: 1.0s\n'
|
15
tests/test_utils/test_torch_ops.py
Normal file
15
tests/test_utils/test_torch_ops.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.utils import torch_meshgrid
|
||||
|
||||
|
||||
def test_torch_meshgrid():
|
||||
# torch_meshgrid should not throw warning
|
||||
with pytest.warns(None) as record:
|
||||
x = torch.tensor([1, 2, 3])
|
||||
y = torch.tensor([4, 5, 6])
|
||||
grid_x, grid_y = torch_meshgrid(x, y)
|
||||
|
||||
assert len(record) == 0
|
25
tests/test_utils/test_trace.py
Normal file
25
tests/test_utils/test_trace.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.utils import digit_version, is_jit_tracing
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
digit_version(torch.__version__) < digit_version('1.6.0'),
|
||||
reason='torch.jit.is_tracing is not available before 1.6.0')
|
||||
def test_is_jit_tracing():
|
||||
|
||||
def foo(x):
|
||||
if is_jit_tracing():
|
||||
return x
|
||||
else:
|
||||
return x.tolist()
|
||||
|
||||
x = torch.rand(3)
|
||||
# test without trace
|
||||
assert isinstance(foo(x), list)
|
||||
|
||||
# test with trace
|
||||
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
|
||||
assert isinstance(traced_foo(x), torch.Tensor)
|
Loading…
x
Reference in New Issue
Block a user