mirror of https://github.com/open-mmlab/mmcv.git
parent
f71e47c2f7
commit
f7caa80f9c
|
@ -4,7 +4,8 @@ from .config import Config, ConfigDict, DictAction
|
|||
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
|
||||
import_modules_from_strings, is_list_of, is_seq_of, is_str,
|
||||
is_tuple_of, iter_cast, list_cast, requires_executable,
|
||||
requires_package, slice_list, tuple_cast)
|
||||
requires_package, slice_list, to_1tuple, to_2tuple,
|
||||
to_3tuple, to_4tuple, to_ntuple, tuple_cast)
|
||||
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
|
||||
scandir, symlink)
|
||||
from .progressbar import (ProgressBar, track_iter_progress,
|
||||
|
@ -29,17 +30,18 @@ except ImportError:
|
|||
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
|
||||
'digit_version', 'get_git_hash', 'import_modules_from_strings',
|
||||
'assert_dict_contains_subset', 'assert_attrs_equal',
|
||||
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script'
|
||||
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
|
||||
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple'
|
||||
]
|
||||
else:
|
||||
from .env import collect_env
|
||||
from .logging import get_logger, print_log
|
||||
from .parrots_jit import jit, skip_no_elena
|
||||
from .parrots_wrapper import (
|
||||
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
|
||||
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
|
||||
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
|
||||
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
|
||||
from .parrots_jit import jit, skip_no_elena
|
||||
from .registry import Registry, build_from_cfg
|
||||
__all__ = [
|
||||
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import collections.abc
|
||||
import functools
|
||||
import itertools
|
||||
import subprocess
|
||||
|
@ -6,6 +7,25 @@ import warnings
|
|||
from collections import abc
|
||||
from importlib import import_module
|
||||
from inspect import getfullargspec
|
||||
from itertools import repeat
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_1tuple = _ntuple(1)
|
||||
to_2tuple = _ntuple(2)
|
||||
to_3tuple = _ntuple(3)
|
||||
to_4tuple = _ntuple(4)
|
||||
to_ntuple = _ntuple
|
||||
|
||||
|
||||
def is_str(x):
|
||||
|
|
|
@ -4,6 +4,31 @@ import pytest
|
|||
import mmcv
|
||||
|
||||
|
||||
def test_to_ntuple():
|
||||
single_number = 2
|
||||
assert mmcv.utils.to_1tuple(single_number) == (single_number, )
|
||||
assert mmcv.utils.to_2tuple(single_number) == (single_number,
|
||||
single_number)
|
||||
assert mmcv.utils.to_3tuple(single_number) == (single_number,
|
||||
single_number,
|
||||
single_number)
|
||||
assert mmcv.utils.to_4tuple(single_number) == (single_number,
|
||||
single_number,
|
||||
single_number,
|
||||
single_number)
|
||||
assert mmcv.utils.to_ntuple(5)(single_number) == (single_number,
|
||||
single_number,
|
||||
single_number,
|
||||
single_number,
|
||||
single_number)
|
||||
assert mmcv.utils.to_ntuple(6)(single_number) == (single_number,
|
||||
single_number,
|
||||
single_number,
|
||||
single_number,
|
||||
single_number,
|
||||
single_number)
|
||||
|
||||
|
||||
def test_iter_cast():
|
||||
assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3]
|
||||
assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0]
|
||||
|
@ -105,6 +130,7 @@ def test_requires_executable(capsys):
|
|||
def test_import_modules_from_strings():
|
||||
# multiple imports
|
||||
import os.path as osp_
|
||||
|
||||
import sys as sys_
|
||||
osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys'])
|
||||
assert osp == osp_
|
||||
|
|
Loading…
Reference in New Issue