make compatible with pytorch 0.4.1 and 1.0

pull/28/head
Kai Chen 2018-12-25 15:13:44 +08:00
parent 076cdd6c74
commit b64d43873d
3 changed files with 11 additions and 3 deletions

View File

@ -33,7 +33,10 @@ class MMDistributedDataParallel(nn.Module):
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
buffers = [b.data for b in self.module.buffers()]
if torch.__version__ < '1.0':
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
if len(buffers) > 0:
self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size)

View File

@ -5,6 +5,7 @@ from getpass import getuser
from socket import gethostname
import mmcv
import torch
import torch.distributed as dist
@ -13,7 +14,11 @@ def get_host_info():
def get_dist_info():
if dist.is_initialized():
if torch.__version__ < '1.0':
initialized = dist._initialized
else:
initialized = dist.is_initialized()
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:

View File

@ -1 +1 @@
__version__ = '0.2.2'
__version__ = '0.2.3'