mirror of https://github.com/open-mmlab/mmcv.git
Fix mmcv-dataparallel (#497)

* Fix mmcv-dataparallel
* Fix (parallel): fix CPU inference with MMDataParallel
* Update docstrings
* Doc (parallel): refine docstrings
* Fix (parallel): fix missing changes of the train/val step functions
* Resolve comments
* Fix (data_parallel): fix bug when single-GPU test returns None

parent 4ec73abbcc
commit 7a6285b190
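
For context, a minimal sketch of what this change enables. The toy module below is an assumption made for illustration; only the MMDataParallel behavior comes from this commit. On a machine without CUDA, PyTorch's DataParallel leaves ``device_ids`` empty, so the new CPU branch in ``forward`` is taken instead of raising an error.

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class SimpleModel(nn.Module):
    """Stand-in model; any nn.Module whose forward accepts the inputs works."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, img):
        return self.conv(img)


# On a CUDA-less machine DataParallel sets device_ids to an empty list,
# which now routes the call through the CPU scatter path instead of failing.
model = MMDataParallel(SimpleModel())
out = model(torch.rand(1, 3, 32, 32))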

@@ -19,8 +19,13 @@ def scatter(input, devices, streams=None):
         output = input.contiguous()
         # TODO: copy to a pinned buffer first (if copying from CPU)
         stream = streams[0] if output.numel() > 0 else None
-        with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
-            output = output.cuda(devices[0], non_blocking=True)
+        if devices != [-1]:
+            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+                output = output.cuda(devices[0], non_blocking=True)
+        else:
+            # unsqueeze the first dimension so that the tensor's shape is
+            # the same as those scattered to GPU
+            output = output.unsqueeze(0)
         return output
     else:
         raise Exception(f'Unknown type {type(input)}.')
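
The CPU branch above mirrors the structure of a GPU scatter result: scattering to a single GPU yields one chunk per device, so on CPU the tensor gets a leading per-device dimension instead of a device copy. A small shape sketch in plain PyTorch, independent of mmcv, just illustrating the bookkeeping:

import torch

x = torch.randn(4, 3)        # a CPU tensor about to be "scattered"

# GPU path: torch's Scatter returns one chunk per target device, e.g. a
# 1-tuple holding a (4, 3) tensor when there is a single GPU.
# CPU path after this change: unsqueeze(0) adds the per-device axis, so
# iterating over the result again yields the full (4, 3) sample.
scattered = x.unsqueeze(0)   # shape (1, 4, 3)
chunks = tuple(scattered)    # (tensor of shape (4, 3),)
assert chunks[0].shape == x.shape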

@@ -62,7 +67,7 @@ class Scatter:
     def forward(target_gpus, input):
         input_device = get_input_device(input)
         streams = None
-        if input_device == -1:
+        if input_device == -1 and target_gpus != [-1]:
             # Perform CPU to GPU copies in a background stream
             streams = [_get_stream(device) for device in target_gpus]
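
This hunk keeps the background copy streams strictly a GPU concern: CUDA streams are only requested when a CPU-resident input is actually being copied to GPUs. A quick self-contained check of the new condition (that ``get_input_device`` returns -1 for CPU tensors comes from the surrounding mmcv code, not shown in this hunk):

input_device = -1     # the input currently lives on CPU

target_gpus = [0]     # GPU inference: copy streams are still created
assert input_device == -1 and target_gpus != [-1]

target_gpus = [-1]    # CPU inference: no CUDA streams are touched at all
assert not (input_device == -1 and target_gpus != [-1])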

@@ -7,12 +7,48 @@ from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):
    """The DataParallel module that supports DataContainer.

    MMDataParallel has two main differences with PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
    - It implements two more APIs: ``train_step()`` and ``val_step()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        device_ids (list[int]): Device IDs of modules to be scattered to.
            Defaults to None when GPU is not available.
        output_device (str | int): Device ID for output. Defaults to None.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
        self.dim = dim

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainer` will still be gathered.
        """
        if not self.device_ids:
            # We add the following line so that the module can gather and
            # convert data containers as in GPU inference
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
        else:
            return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # We add the following line so that the module can gather and
            # convert data containers as in GPU inference
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.train_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
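
To see why ``forward`` and ``train_step`` index ``inputs[0]`` and ``kwargs[0]``: ``scatter_kwargs`` returns one tuple of positional args and one dict of keyword args per target device, and ``[-1]`` stands for a single CPU "device". A short sketch, assuming a CPU-only setup and mmcv's usual package layout for the absolute import:

import torch
from mmcv.parallel.scatter_gather import scatter_kwargs

inputs, kwargs = scatter_kwargs((torch.randn(2, 3),), dict(flag=True), [-1])

# One entry per target device; [-1] means a single CPU "device".
assert len(inputs) == 1 and len(kwargs) == 1
args_for_cpu, kwargs_for_cpu = inputs[0], kwargs[0]  # what forward() unpacks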

@@ -32,6 +68,9 @@ class MMDataParallel(DataParallel):

    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # We add the following line so that the module can gather and
            # convert data containers as in GPU inference
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.val_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
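
``train_step`` and ``val_step`` get the same CPU fallback as ``forward``. A hedged sketch of how a runner-style loop would drive them through the wrapper; the toy module and its return dict follow mmcv's usual train_step convention but are assumptions, not part of this commit:

import torch
import torch.nn as nn

from mmcv.parallel import MMDataParallel


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

    def train_step(self, data, optimizer=None):
        loss = self(data['x']).mean()
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=len(data['x']))

    def val_step(self, data, optimizer=None):
        return self.train_step(data, optimizer)


# Without GPUs device_ids is empty and the new CPU branch is taken; with a
# single GPU one would construct MMDataParallel(ToyModel(), device_ids=[0]).
model = MMDataParallel(ToyModel())
outputs = model.train_step(dict(x=torch.randn(2, 3)))
assert 'loss' in outputs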

@@ -15,7 +15,11 @@ def scatter(inputs, target_gpus, dim=0):

     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
-            return OrigScatter.apply(target_gpus, None, dim, obj)
+            if target_gpus != [-1]:
+                return OrigScatter.apply(target_gpus, None, dim, obj)
+            else:
+                # for CPU inference we use self-implemented scatter
+                return Scatter.forward(target_gpus, obj)
         if isinstance(obj, DataContainer):
             if obj.cpu_only:
                 return obj.data
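
The same dispatch also covers :class:`DataContainer`: anything marked ``cpu_only`` (typically image metas) is returned as-is rather than scattered. A brief sketch, assuming mmcv's DataContainer API and package layout:

from mmcv.parallel import DataContainer
from mmcv.parallel.scatter_gather import scatter

# Metadata wrapped as cpu_only passes through untouched, regardless of the
# target devices; tensors would instead take the branch shown above.
metas = DataContainer([[dict(filename='demo.jpg')]], cpu_only=True)
scattered = scatter((metas,), [-1])
assert scattered[0][0] == [dict(filename='demo.jpg')]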