mirror of https://github.com/open-mmlab/mmcv.git
make compatible with pytorch 0.4.1 and 1.0
parent 076cdd6c74
commit b64d43873d
@@ -33,7 +33,10 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(module_states,
                                            self.broadcast_bucket_size)
         if self.broadcast_buffers:
-            buffers = [b.data for b in self.module.buffers()]
+            if torch.__version__ < '1.0':
+                buffers = [b.data for b in self.module._all_buffers()]
+            else:
+                buffers = [b.data for b in self.module.buffers()]
             if len(buffers) > 0:
                 self._dist_broadcast_coalesced(buffers,
                                                self.broadcast_bucket_size)
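
For reference, a minimal standalone sketch of the version-gated buffer collection introduced above; the helper name collect_buffers and the BatchNorm2d usage are illustrative only, not part of this commit:

import torch
import torch.nn as nn

def collect_buffers(module):
    # PyTorch 0.4.x exposes buffers only through the private
    # Module._all_buffers(); PyTorch 1.0 added the public Module.buffers()
    # iterator, so the version string decides which one to call.
    if torch.__version__ < '1.0':
        return [b.data for b in module._all_buffers()]
    return [b.data for b in module.buffers()]

bn = nn.BatchNorm2d(16)  # registers running_mean / running_var buffers
print(len(collect_buffers(bn)))  # number of registered buffers
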
@@ -5,6 +5,7 @@ from getpass import getuser
 from socket import gethostname
 
 import mmcv
+import torch
 import torch.distributed as dist
 
 
@@ -13,7 +14,11 @@ def get_host_info():
 
 
 def get_dist_info():
-    if dist.is_initialized():
+    if torch.__version__ < '1.0':
+        initialized = dist._initialized
+    else:
+        initialized = dist.is_initialized()
+    if initialized:
         rank = dist.get_rank()
         world_size = dist.get_world_size()
     else:
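
Put together, the patched helper reads roughly as follows. The final fallback branch (rank 0, world_size 1 for non-distributed runs) lies outside the hunk above and is an assumption here:

import torch
import torch.distributed as dist

def get_dist_info():
    # Pre-1.0 torch.distributed tracked initialization in the private
    # _initialized flag; 1.0 introduced the public is_initialized() call.
    if torch.__version__ < '1.0':
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        # Assumed fallback when no process group has been set up.
        rank = 0
        world_size = 1
    return rank, world_size

rank, world_size = get_dist_info()
print(rank, world_size)  # 0 1 in a single-process run
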
@@ -1 +1 @@
-__version__ = '0.2.2'
+__version__ = '0.2.3'
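
As a quick sanity check after upgrading, the installed version can be read back at runtime; the expected output below follows from the bump in this commit rather than from a captured run:

import mmcv
print(mmcv.__version__)  # expected: '0.2.3'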