mirror of https://github.com/open-mmlab/mmcv.git
parent
6fc6f75a76
commit
57f3a6142f
|
@ -18,6 +18,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
|
|||
- It implement two APIs ``train_step()`` and ``val_step()``.
|
||||
"""
|
||||
|
||||
def to_kwargs(self, inputs, kwargs, device_id):
|
||||
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
|
||||
# to move all tensors to device_id
|
||||
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
|
||||
|
||||
def scatter(self, inputs, kwargs, device_ids):
|
||||
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
|
||||
|
||||
|
|
Loading…
Reference in New Issue