mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Update digit_version (#778)
* update digit_version * add unittest * fix import
This commit is contained in:
parent
58f5dbce7d
commit
bfc3cdbdaf
@ -1,4 +1,7 @@
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
from packaging.version import parse
|
||||
|
||||
from .version import __version__, version_info
|
||||
|
||||
@ -6,16 +9,44 @@ MMCV_MIN = '1.3.7'
|
||||
MMCV_MAX = '1.4.0'
|
||||
|
||||
|
||||
def digit_version(version_str):
|
||||
digit_version = []
|
||||
for x in version_str.split('.'):
|
||||
if x.isdigit():
|
||||
digit_version.append(int(x))
|
||||
elif x.find('rc') != -1:
|
||||
patch_version = x.split('rc')
|
||||
digit_version.append(int(patch_version[0]) - 1)
|
||||
digit_version.append(int(patch_version[1]))
|
||||
return digit_version
|
||||
def digit_version(version_str: str, length: int = 4):
|
||||
"""Convert a version string into a tuple of integers.
|
||||
|
||||
This method is usually used for comparing two versions. For pre-release
|
||||
versions: alpha < beta < rc.
|
||||
|
||||
Args:
|
||||
version_str (str): The version string.
|
||||
length (int): The maximum number of version levels. Default: 4.
|
||||
|
||||
Returns:
|
||||
tuple[int]: The version info in digits (integers).
|
||||
"""
|
||||
version = parse(version_str)
|
||||
assert version.release, f'failed to parse version {version_str}'
|
||||
release = list(version.release)
|
||||
release = release[:length]
|
||||
if len(release) < length:
|
||||
release = release + [0] * (length - len(release))
|
||||
if version.is_prerelease:
|
||||
mapping = {'a': -3, 'b': -2, 'rc': -1}
|
||||
val = -4
|
||||
# version.pre can be None
|
||||
if version.pre:
|
||||
if version.pre[0] not in mapping:
|
||||
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
|
||||
'version checking may go wrong')
|
||||
else:
|
||||
val = mapping[version.pre[0]]
|
||||
release.extend([val, version.pre[-1]])
|
||||
else:
|
||||
release.extend([val, 0])
|
||||
|
||||
elif version.is_postrelease:
|
||||
release.extend([1, version.post])
|
||||
else:
|
||||
release.extend([0, 0])
|
||||
return tuple(release)
|
||||
|
||||
|
||||
mmcv_min_version = digit_version(MMCV_MIN)
|
||||
@ -27,4 +58,4 @@ assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
|
||||
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
|
||||
|
||||
__all__ = ['__version__', 'version_info']
|
||||
__all__ = ['__version__', 'version_info', 'digit_version']
|
||||
|
@ -7,7 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import collate
|
||||
from mmcv.runner import get_dist_info
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
from mmcv.utils import Registry, build_from_cfg, digit_version
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
if platform.system() != 'Windows':
|
||||
@ -133,7 +133,7 @@ def build_dataloader(dataset,
|
||||
worker_init_fn, num_workers=num_workers, rank=rank,
|
||||
seed=seed) if seed is not None else None
|
||||
|
||||
if torch.__version__ >= '1.8.0':
|
||||
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
||||
data_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
|
@ -1,3 +1,4 @@
|
||||
matplotlib
|
||||
numpy
|
||||
packaging
|
||||
prettytable
|
||||
|
@ -8,6 +8,6 @@ line_length = 79
|
||||
multi_line_output = 0
|
||||
known_standard_library = setuptools
|
||||
known_first_party = mmseg
|
||||
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
|
||||
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
20
tests/test_digit_version.py
Normal file
20
tests/test_digit_version.py
Normal file
@ -0,0 +1,20 @@
|
||||
from mmseg import digit_version
|
||||
|
||||
|
||||
def test_digit_version():
|
||||
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
|
||||
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
|
||||
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
|
||||
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
|
||||
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
|
||||
assert digit_version('1.0') == digit_version('1.0.0')
|
||||
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
|
||||
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0b')
|
||||
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
|
||||
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
|
||||
assert digit_version('1.0.0') < digit_version('1.0.0post')
|
||||
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
|
||||
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
|
||||
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
|
Loading…
x
Reference in New Issue
Block a user