From 6f6b17e65f64fa21c7e5fc71f8fb28f23b6097f1 Mon Sep 17 00:00:00 2001 From: mattcasey02 <35430698+mattcasey02@users.noreply.github.com> Date: Sun, 17 Apr 2022 22:47:56 -0400 Subject: [PATCH] [Fix] Fixed version comparison to include prerelease versions (#1877) * Fixed version comparison to include prerelease versions Currently all tagged versions of torch 1.11.0 have version 1.11.0a0. Previously the comparison to 1.11.0 failed and self._sync_params() was still used, causing an error. This fix should include all versions of 1.11. * Same update Didn't realize that 1.11.0 was mentioned multiple times in the file. This fixes the other instances. --- mmcv/parallel/distributed.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py index 791b6c080..0188ca4ab 100644 --- a/mmcv/parallel/distributed.py +++ b/mmcv/parallel/distributed.py @@ -45,7 +45,7 @@ class MMDistributedDataParallel(DistributedDataParallel): logger='mmcv') if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_pre_fwd(): self._sync_buffers() else: @@ -65,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel): output = self.module.train_step(*inputs, **kwargs) if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_post_fwd(): self._sync_buffers() @@ -100,7 +100,7 @@ class MMDistributedDataParallel(DistributedDataParallel): logger='mmcv') if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_pre_fwd(): self._sync_buffers() else: @@ -120,7 +120,7 @@ class MMDistributedDataParallel(DistributedDataParallel): output = self.module.val_step(*inputs, **kwargs) if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_post_fwd(): self._sync_buffers()