diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py index 791b6c080..0188ca4ab 100644 --- a/mmcv/parallel/distributed.py +++ b/mmcv/parallel/distributed.py @@ -45,7 +45,7 @@ class MMDistributedDataParallel(DistributedDataParallel): logger='mmcv') if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_pre_fwd(): self._sync_buffers() else: @@ -65,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel): output = self.module.train_step(*inputs, **kwargs) if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_post_fwd(): self._sync_buffers() @@ -100,7 +100,7 @@ class MMDistributedDataParallel(DistributedDataParallel): logger='mmcv') if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_pre_fwd(): self._sync_buffers() else: @@ -120,7 +120,7 @@ class MMDistributedDataParallel(DistributedDataParallel): output = self.module.val_step(*inputs, **kwargs) if ('parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): if self._check_sync_bufs_post_fwd(): self._sync_buffers()