mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Update digit_version (#778)
* update digit_version * add unittest * fix import
This commit is contained in:
parent
58f5dbce7d
commit
bfc3cdbdaf
@ -1,4 +1,7 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
|
from packaging.version import parse
|
||||||
|
|
||||||
from .version import __version__, version_info
|
from .version import __version__, version_info
|
||||||
|
|
||||||
@ -6,16 +9,44 @@ MMCV_MIN = '1.3.7'
|
|||||||
MMCV_MAX = '1.4.0'
|
MMCV_MAX = '1.4.0'
|
||||||
|
|
||||||
|
|
||||||
def digit_version(version_str):
|
def digit_version(version_str: str, length: int = 4):
|
||||||
digit_version = []
|
"""Convert a version string into a tuple of integers.
|
||||||
for x in version_str.split('.'):
|
|
||||||
if x.isdigit():
|
This method is usually used for comparing two versions. For pre-release
|
||||||
digit_version.append(int(x))
|
versions: alpha < beta < rc.
|
||||||
elif x.find('rc') != -1:
|
|
||||||
patch_version = x.split('rc')
|
Args:
|
||||||
digit_version.append(int(patch_version[0]) - 1)
|
version_str (str): The version string.
|
||||||
digit_version.append(int(patch_version[1]))
|
length (int): The maximum number of version levels. Default: 4.
|
||||||
return digit_version
|
|
||||||
|
Returns:
|
||||||
|
tuple[int]: The version info in digits (integers).
|
||||||
|
"""
|
||||||
|
version = parse(version_str)
|
||||||
|
assert version.release, f'failed to parse version {version_str}'
|
||||||
|
release = list(version.release)
|
||||||
|
release = release[:length]
|
||||||
|
if len(release) < length:
|
||||||
|
release = release + [0] * (length - len(release))
|
||||||
|
if version.is_prerelease:
|
||||||
|
mapping = {'a': -3, 'b': -2, 'rc': -1}
|
||||||
|
val = -4
|
||||||
|
# version.pre can be None
|
||||||
|
if version.pre:
|
||||||
|
if version.pre[0] not in mapping:
|
||||||
|
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
|
||||||
|
'version checking may go wrong')
|
||||||
|
else:
|
||||||
|
val = mapping[version.pre[0]]
|
||||||
|
release.extend([val, version.pre[-1]])
|
||||||
|
else:
|
||||||
|
release.extend([val, 0])
|
||||||
|
|
||||||
|
elif version.is_postrelease:
|
||||||
|
release.extend([1, version.post])
|
||||||
|
else:
|
||||||
|
release.extend([0, 0])
|
||||||
|
return tuple(release)
|
||||||
|
|
||||||
|
|
||||||
mmcv_min_version = digit_version(MMCV_MIN)
|
mmcv_min_version = digit_version(MMCV_MIN)
|
||||||
@ -27,4 +58,4 @@ assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
|
|||||||
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||||
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
|
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
|
||||||
|
|
||||||
__all__ = ['__version__', 'version_info']
|
__all__ = ['__version__', 'version_info', 'digit_version']
|
||||||
|
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from mmcv.parallel import collate
|
from mmcv.parallel import collate
|
||||||
from mmcv.runner import get_dist_info
|
from mmcv.runner import get_dist_info
|
||||||
from mmcv.utils import Registry, build_from_cfg
|
from mmcv.utils import Registry, build_from_cfg, digit_version
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
|
||||||
if platform.system() != 'Windows':
|
if platform.system() != 'Windows':
|
||||||
@ -133,7 +133,7 @@ def build_dataloader(dataset,
|
|||||||
worker_init_fn, num_workers=num_workers, rank=rank,
|
worker_init_fn, num_workers=num_workers, rank=rank,
|
||||||
seed=seed) if seed is not None else None
|
seed=seed) if seed is not None else None
|
||||||
|
|
||||||
if torch.__version__ >= '1.8.0':
|
if digit_version(torch.__version__) >= digit_version('1.8.0'):
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
matplotlib
|
matplotlib
|
||||||
numpy
|
numpy
|
||||||
|
packaging
|
||||||
prettytable
|
prettytable
|
||||||
|
@ -8,6 +8,6 @@ line_length = 79
|
|||||||
multi_line_output = 0
|
multi_line_output = 0
|
||||||
known_standard_library = setuptools
|
known_standard_library = setuptools
|
||||||
known_first_party = mmseg
|
known_first_party = mmseg
|
||||||
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
|
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts
|
||||||
no_lines_before = STDLIB,LOCALFOLDER
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
default_section = THIRDPARTY
|
default_section = THIRDPARTY
|
||||||
|
20
tests/test_digit_version.py
Normal file
20
tests/test_digit_version.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from mmseg import digit_version
|
||||||
|
|
||||||
|
|
||||||
|
def test_digit_version():
|
||||||
|
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
|
||||||
|
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
|
||||||
|
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
|
||||||
|
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
|
||||||
|
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
|
||||||
|
assert digit_version('1.0') == digit_version('1.0.0')
|
||||||
|
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
|
||||||
|
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
|
||||||
|
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
|
||||||
|
assert digit_version('1.0.0a') < digit_version('1.0.0b')
|
||||||
|
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
|
||||||
|
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
|
||||||
|
assert digit_version('1.0.0') < digit_version('1.0.0post')
|
||||||
|
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
|
||||||
|
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
|
||||||
|
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
|
Loading…
x
Reference in New Issue
Block a user