[Fix] Fixed version comparison to include prerelease versions (#1877)

* Fixed version comparison to include prerelease versions

Currently all tagged versions of torch 1.11.0 have version 1.11.0a0. Previously the comparison to 1.11.0 failed and self._sync_params() was still used, causing an error. This fix should include all versions of 1.11.

* Same update

Didn't realize that 1.11.0 was mentioned multiple times in the file. This fixes the other instances.
This commit is contained in:
mattcasey02 2022-04-17 22:47:56 -04:00 committed by GitHub
parent 5221a3883c
commit 6f6b17e65f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -45,7 +45,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
logger='mmcv')
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
@ -65,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
output = self.module.train_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
@ -100,7 +100,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
logger='mmcv')
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
@ -120,7 +120,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
output = self.module.val_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()