Fix mmcv-dataparallel (#497)

* Fix mmcv-dataparallel

* Fix (parallel): fix CPU inference with MMDataParallel

* Update docstrings

* Doc (parallel): refine docstrings

* Fix (parallel): fix missing changes in train/val step functions

* Resolve comments

* Fix (data_parallel): fix bug when single GPU test returns None
Wenwei Zhang 2020-08-19 14:22:40 +08:00 committed by GitHub
parent 4ec73abbcc
commit 7a6285b190
3 changed files with 52 additions and 4 deletions
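
For context, a minimal sketch of how the CPU path fixed here is exercised. The toy module and input are hypothetical, and a CPU-only environment is assumed (no CUDA device, so DataParallel leaves device_ids empty):

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class ToyModel(nn.Module):

    def forward(self, img):
        return img.mean()


# Without visible GPUs, device_ids is empty, so MMDataParallel.forward
# takes the new branch: inputs are scattered with device id -1 (i.e. kept
# on CPU) instead of being copied to a GPU.
model = MMDataParallel(ToyModel())
out = model(torch.randn(1, 3, 32, 32))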

mmcv/parallel/_functions.py

@@ -19,8 +19,13 @@ def scatter(input, devices, streams=None):
         output = input.contiguous()
         # TODO: copy to a pinned buffer first (if copying from CPU)
         stream = streams[0] if output.numel() > 0 else None
-        with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
-            output = output.cuda(devices[0], non_blocking=True)
+        if devices != [-1]:
+            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+                output = output.cuda(devices[0], non_blocking=True)
+        else:
+            # unsqueeze the first dimension so the tensor's shape is the
+            # same as those scattered to GPU.
+            output = output.unsqueeze(0)
         return output
     else:
         raise Exception(f'Unknown type {type(input)}.')
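
The comment above hints at why the CPU branch unsqueezes: downstream code treats the result like a per-device sequence, and a leading dimension of size one gives the same nesting as a scatter onto a single GPU. A quick plain-torch sanity check of that shape behaviour (not mmcv's actual API):

import torch

x = torch.randn(4, 3)

# CPU "scatter": add a leading device dimension instead of copying to GPU.
cpu_scattered = x.unsqueeze(0)

# Iterating the first dimension gives one chunk per "device", just as a
# scatter over a single GPU would.
chunks = tuple(cpu_scattered)
assert len(chunks) == 1 and torch.equal(chunks[0], x)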
@@ -62,7 +67,7 @@ class Scatter:
     def forward(target_gpus, input):
         input_device = get_input_device(input)
         streams = None
-        if input_device == -1:
+        if input_device == -1 and target_gpus != [-1]:
             # Perform CPU to GPU copies in a background stream
             streams = [_get_stream(device) for device in target_gpus]

mmcv/parallel/data_parallel.py

@@ -7,12 +7,48 @@ from .scatter_gather import scatter_kwargs

class MMDataParallel(DataParallel):
    """The DataParallel module that supports DataContainer.

    MMDataParallel has two main differences with PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
    - It implements two more APIs ``train_step()`` and ``val_step()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        device_ids (list[int]): Device IDs of modules to be scattered to.
            Defaults to None when GPU is not available.
        output_device (str | int): Device ID for output. Defaults to None.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
        self.dim = dim

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainer` will still be gathered.
        """
        if not self.device_ids:
            # We add the following line so that the module can gather and
            # convert data containers as it does in GPU inference
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
        else:
            return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # We add the following line so that the module can gather and
            # convert data containers as it does in GPU inference
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.train_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
@@ -32,6 +68,9 @@ class MMDataParallel(DataParallel):
    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # We add the following line so that the module can gather and
            # convert data containers as it does in GPU inference
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.val_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
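
The same CPU branch now applies to train_step() and val_step(). A rough sketch, again assuming a CPU-only environment and a hypothetical toy module that implements the train_step/val_step interface:

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

    def train_step(self, x):
        # a real model would also compute and log its losses here
        return dict(loss=self.fc(x).mean())

    def val_step(self, x):
        with torch.no_grad():
            return dict(loss=self.fc(x).mean())


model = MMDataParallel(ToyModel())
# Both calls go through scatter(..., [-1]) and then dispatch to the
# wrapped module, mirroring the GPU code path.
losses = model.train_step(torch.randn(4, 3))
val_losses = model.val_step(torch.randn(4, 3))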

mmcv/parallel/scatter_gather.py

@@ -15,7 +15,11 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
-            return OrigScatter.apply(target_gpus, None, dim, obj)
+            if target_gpus != [-1]:
+                return OrigScatter.apply(target_gpus, None, dim, obj)
+            else:
+                # for CPU inference we use self-implemented scatter
+                return Scatter.forward(target_gpus, obj)
         if isinstance(obj, DataContainer):
             if obj.cpu_only:
                 return obj.data
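
A small check of the scatter() entry point itself. This assumes mmcv exposes it as mmcv.parallel.scatter (true in current releases, but an assumption here); with target_gpus=[-1] the tensor should stay on CPU and come back as a one-element, per-device-style tuple:

import torch

from mmcv.parallel import scatter

x = torch.randn(2, 3)

# CPU path added above: no CUDA copy, the self-implemented Scatter wraps
# the tensor so the result still looks like a per-device tuple.
out = scatter(x, [-1])
assert len(out) == 1 and not out[0].is_cuda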