fix scatter in pytorch18 (#882)

* fix scatter in pytorch18

* remove blanks
pull/898/head
ZhangShilong 2021-03-11 13:13:31 +08:00 committed by GitHub
parent 6fc6f75a76
commit 57f3a6142f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -18,6 +18,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
- It implement two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs, kwargs, device_id):
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)