[Enhance] Update `digit_version` function. (#402)
* Update digit_version * Add digit_version unit testpull/409/head
parent
f8f1700860
commit
c0a4916a06
|
@ -1,18 +1,49 @@
|
|||
import warnings
|
||||
|
||||
import mmcv
|
||||
from packaging.version import parse
|
||||
|
||||
from .version import __version__
|
||||
|
||||
|
||||
def digit_version(version_str):
|
||||
digit_version = []
|
||||
for x in version_str.split('.'):
|
||||
if x.isdigit():
|
||||
digit_version.append(int(x))
|
||||
elif x.find('rc') != -1:
|
||||
patch_version = x.split('rc')
|
||||
digit_version.append(int(patch_version[0]) - 1)
|
||||
digit_version.append(int(patch_version[1]))
|
||||
return digit_version
|
||||
def digit_version(version_str: str, length: int = 4):
|
||||
"""Convert a version string into a tuple of integers.
|
||||
|
||||
This method is usually used for comparing two versions. For pre-release
|
||||
versions: alpha < beta < rc.
|
||||
|
||||
Args:
|
||||
version_str (str): The version string.
|
||||
length (int): The maximum number of version levels. Default: 4.
|
||||
|
||||
Returns:
|
||||
tuple[int]: The version info in digits (integers).
|
||||
"""
|
||||
version = parse(version_str)
|
||||
assert version.release, f'failed to parse version {version_str}'
|
||||
release = list(version.release)
|
||||
release = release[:length]
|
||||
if len(release) < length:
|
||||
release = release + [0] * (length - len(release))
|
||||
if version.is_prerelease:
|
||||
mapping = {'a': -3, 'b': -2, 'rc': -1}
|
||||
val = -4
|
||||
# version.pre can be None
|
||||
if version.pre:
|
||||
if version.pre[0] not in mapping:
|
||||
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
|
||||
'version checking may go wrong')
|
||||
else:
|
||||
val = mapping[version.pre[0]]
|
||||
release.extend([val, version.pre[-1]])
|
||||
else:
|
||||
release.extend([val, 0])
|
||||
|
||||
elif version.is_postrelease:
|
||||
release.extend([1, version.post])
|
||||
else:
|
||||
release.extend([0, 0])
|
||||
return tuple(release)
|
||||
|
||||
|
||||
mmcv_minimum_version = '1.3.8'
|
||||
|
@ -25,4 +56,4 @@ assert (mmcv_version >= digit_version(mmcv_minimum_version)
|
|||
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
|
||||
|
||||
__all__ = ['__version__']
|
||||
__all__ = ['__version__', 'digit_version']
|
||||
|
|
|
@ -14,6 +14,6 @@ line_length = 79
|
|||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = mmcls
|
||||
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,pytest,seaborn,torch,torchvision,ts
|
||||
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
from mmcls import digit_version
|
||||
|
||||
|
||||
def test_digit_version():
|
||||
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
|
||||
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
|
||||
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
|
||||
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
|
||||
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
|
||||
assert digit_version('1.0') == digit_version('1.0.0')
|
||||
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
|
||||
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0b')
|
||||
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
|
||||
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
|
||||
assert digit_version('1.0.0') < digit_version('1.0.0post')
|
||||
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
|
||||
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
|
||||
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
|
Loading…
Reference in New Issue