From bfc3cdbdaf4fa17f446ebfc7adc68c4b2b85a1bb Mon Sep 17 00:00:00 2001 From: Junjun2016 Date: Thu, 12 Aug 2021 14:34:02 +0800 Subject: [PATCH] [Fix] Update digit_version (#778) * update digit_version * add unittest * fix import --- mmseg/__init__.py | 53 +++++++++++++++++++++++++++++-------- mmseg/datasets/builder.py | 4 +-- requirements/runtime.txt | 1 + setup.cfg | 2 +- tests/test_digit_version.py | 20 ++++++++++++++ 5 files changed, 66 insertions(+), 14 deletions(-) create mode 100644 tests/test_digit_version.py diff --git a/mmseg/__init__.py b/mmseg/__init__.py index dbdebf994..317622c92 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -1,4 +1,7 @@ +import warnings + import mmcv +from packaging.version import parse from .version import __version__, version_info @@ -6,16 +9,44 @@ MMCV_MIN = '1.3.7' MMCV_MAX = '1.4.0' -def digit_version(version_str): - digit_version = [] - for x in version_str.split('.'): - if x.isdigit(): - digit_version.append(int(x)) - elif x.find('rc') != -1: - patch_version = x.split('rc') - digit_version.append(int(patch_version[0]) - 1) - digit_version.append(int(patch_version[1])) - return digit_version +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) mmcv_min_version = digit_version(MMCV_MIN) @@ -27,4 +58,4 @@ assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ f'MMCV=={mmcv.__version__} is used but incompatible. ' \ f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' -__all__ = ['__version__', 'version_info'] +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py index 5994ab233..82f6f460f 100644 --- a/mmseg/datasets/builder.py +++ b/mmseg/datasets/builder.py @@ -7,7 +7,7 @@ import numpy as np import torch from mmcv.parallel import collate from mmcv.runner import get_dist_info -from mmcv.utils import Registry, build_from_cfg +from mmcv.utils import Registry, build_from_cfg, digit_version from torch.utils.data import DataLoader, DistributedSampler if platform.system() != 'Windows': @@ -133,7 +133,7 @@ def build_dataloader(dataset, worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None - if torch.__version__ >= '1.8.0': + if digit_version(torch.__version__) >= digit_version('1.8.0'): data_loader = DataLoader( dataset, batch_size=batch_size, diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 47048d029..2712f504c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,3 +1,4 @@ matplotlib numpy +packaging prettytable diff --git a/setup.cfg b/setup.cfg index 0dbe479fa..0c80b37ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmseg -known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts +known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_digit_version.py b/tests/test_digit_version.py new file mode 100644 index 000000000..4d6649005 --- /dev/null +++ b/tests/test_digit_version.py @@ -0,0 +1,20 @@ +from mmseg import digit_version + + +def test_digit_version(): + assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0) + assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0) + assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0) + assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1) + assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0) + assert digit_version('1.0') == digit_version('1.0.0') + assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5') + assert digit_version('1.0.0dev') < digit_version('1.0.0a') + assert digit_version('1.0.0a') < digit_version('1.0.0a1') + assert digit_version('1.0.0a') < digit_version('1.0.0b') + assert digit_version('1.0.0b') < digit_version('1.0.0rc') + assert digit_version('1.0.0rc1') < digit_version('1.0.0') + assert digit_version('1.0.0') < digit_version('1.0.0post') + assert digit_version('1.0.0post') < digit_version('1.0.0post1') + assert digit_version('v1') == (1, 0, 0, 0, 0, 0) + assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)