[Enhancement] Add to_ntuple (#1125)

* add to_ntuple

* add unit test
pull/1121/head
Junjun2016 2021-06-23 10:19:28 +08:00 committed by GitHub
parent f71e47c2f7
commit f7caa80f9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 3 deletions

View File

@ -4,7 +4,8 @@ from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
import_modules_from_strings, is_list_of, is_seq_of, is_str,
is_tuple_of, iter_cast, list_cast, requires_executable,
requires_package, slice_list, tuple_cast)
requires_package, slice_list, to_1tuple, to_2tuple,
to_3tuple, to_4tuple, to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
@ -29,17 +30,18 @@ except ImportError:
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script'
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple'
]
else:
from .env import collect_env
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import (
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
from .parrots_jit import jit, skip_no_elena
from .registry import Registry, build_from_cfg
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',

View File

@ -1,4 +1,5 @@
# Copyright (c) Open-MMLab. All rights reserved.
import collections.abc
import functools
import itertools
import subprocess
@ -6,6 +7,25 @@ import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def is_str(x):

View File

@ -4,6 +4,31 @@ import pytest
import mmcv
def test_to_ntuple():
single_number = 2
assert mmcv.utils.to_1tuple(single_number) == (single_number, )
assert mmcv.utils.to_2tuple(single_number) == (single_number,
single_number)
assert mmcv.utils.to_3tuple(single_number) == (single_number,
single_number,
single_number)
assert mmcv.utils.to_4tuple(single_number) == (single_number,
single_number,
single_number,
single_number)
assert mmcv.utils.to_ntuple(5)(single_number) == (single_number,
single_number,
single_number,
single_number,
single_number)
assert mmcv.utils.to_ntuple(6)(single_number) == (single_number,
single_number,
single_number,
single_number,
single_number,
single_number)
def test_iter_cast():
assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3]
assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0]
@ -105,6 +130,7 @@ def test_requires_executable(capsys):
def test_import_modules_from_strings():
# multiple imports
import os.path as osp_
import sys as sys_
osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys'])
assert osp == osp_