mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Fix mmcv-dataparallel (#497)
* Fix mmcv-dataparallel * Fix (parallel): fix CPU inference with MMDataParallel * Update docstrings * Doc (parallel): refine docstrings * Fix (parallel): fix missing changes of train/val step function * resolve comments * Fix (data_parallel): fix bug when single gpu test return None
This commit is contained in:
parent
4ec73abbcc
commit
7a6285b190
@ -19,8 +19,13 @@ def scatter(input, devices, streams=None):
|
|||||||
output = input.contiguous()
|
output = input.contiguous()
|
||||||
# TODO: copy to a pinned buffer first (if copying from CPU)
|
# TODO: copy to a pinned buffer first (if copying from CPU)
|
||||||
stream = streams[0] if output.numel() > 0 else None
|
stream = streams[0] if output.numel() > 0 else None
|
||||||
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
|
if devices != [-1]:
|
||||||
output = output.cuda(devices[0], non_blocking=True)
|
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
|
||||||
|
output = output.cuda(devices[0], non_blocking=True)
|
||||||
|
else:
|
||||||
|
# unsqueeze the first dimension so that the tensor's shape is the
|
||||||
|
# same as those scattered with GPU.
|
||||||
|
output = output.unsqueeze(0)
|
||||||
return output
|
return output
|
||||||
else:
|
else:
|
||||||
raise Exception(f'Unknown type {type(input)}.')
|
raise Exception(f'Unknown type {type(input)}.')
|
||||||
@ -62,7 +67,7 @@ class Scatter:
|
|||||||
def forward(target_gpus, input):
|
def forward(target_gpus, input):
|
||||||
input_device = get_input_device(input)
|
input_device = get_input_device(input)
|
||||||
streams = None
|
streams = None
|
||||||
if input_device == -1:
|
if input_device == -1 and target_gpus != [-1]:
|
||||||
# Perform CPU to GPU copies in a background stream
|
# Perform CPU to GPU copies in a background stream
|
||||||
streams = [_get_stream(device) for device in target_gpus]
|
streams = [_get_stream(device) for device in target_gpus]
|
||||||
|
|
||||||
|
@ -7,12 +7,48 @@ from .scatter_gather import scatter_kwargs
|
|||||||
|
|
||||||
|
|
||||||
class MMDataParallel(DataParallel):
|
class MMDataParallel(DataParallel):
|
||||||
|
"""The DataParallel module that supports DataContainer.
|
||||||
|
|
||||||
|
MMDataParallel has two main differences with PyTorch DataParallel:
|
||||||
|
|
||||||
|
- It supports a custom type :class:`DataContainer` which allows more
|
||||||
|
flexible control of input data during both GPU and CPU inference.
|
||||||
|
- It implements two more APIs, ``train_step()`` and ``val_step()``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (:class:`nn.Module`): Module to be encapsulated.
|
||||||
|
device_ids (list[int]): Device IDs of modules to be scattered to.
|
||||||
|
Defaults to None when GPU is not available.
|
||||||
|
output_device (str | int): Device ID for output. Defaults to None.
|
||||||
|
dim (int): Dimension used to scatter the data. Defaults to 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, dim=0, **kwargs):
|
||||||
|
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, *inputs, **kwargs):
|
||||||
|
"""Override the original forward function.
|
||||||
|
|
||||||
|
The main difference lies in the CPU inference where the data in
|
||||||
|
:class:`DataContainers` will still be gathered.
|
||||||
|
"""
|
||||||
|
if not self.device_ids:
|
||||||
|
# We add the following line so that the module can gather and
|
||||||
|
# convert data containers as those in GPU inference
|
||||||
|
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
|
||||||
|
return self.module(*inputs[0], **kwargs[0])
|
||||||
|
else:
|
||||||
|
return super().forward(*inputs, **kwargs)
|
||||||
|
|
||||||
def scatter(self, inputs, kwargs, device_ids):
|
def scatter(self, inputs, kwargs, device_ids):
|
||||||
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
|
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
|
||||||
|
|
||||||
def train_step(self, *inputs, **kwargs):
|
def train_step(self, *inputs, **kwargs):
|
||||||
if not self.device_ids:
|
if not self.device_ids:
|
||||||
|
# We add the following line so that the module can gather and
|
||||||
|
# convert data containers as those in GPU inference
|
||||||
|
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
|
||||||
return self.module.train_step(*inputs, **kwargs)
|
return self.module.train_step(*inputs, **kwargs)
|
||||||
|
|
||||||
assert len(self.device_ids) == 1, \
|
assert len(self.device_ids) == 1, \
|
||||||
@ -32,6 +68,9 @@ class MMDataParallel(DataParallel):
|
|||||||
|
|
||||||
def val_step(self, *inputs, **kwargs):
|
def val_step(self, *inputs, **kwargs):
|
||||||
if not self.device_ids:
|
if not self.device_ids:
|
||||||
|
# We add the following line so that the module can gather and
|
||||||
|
# convert data containers as those in GPU inference
|
||||||
|
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
|
||||||
return self.module.val_step(*inputs, **kwargs)
|
return self.module.val_step(*inputs, **kwargs)
|
||||||
|
|
||||||
assert len(self.device_ids) == 1, \
|
assert len(self.device_ids) == 1, \
|
||||||
|
@ -15,7 +15,11 @@ def scatter(inputs, target_gpus, dim=0):
|
|||||||
|
|
||||||
def scatter_map(obj):
|
def scatter_map(obj):
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
return OrigScatter.apply(target_gpus, None, dim, obj)
|
if target_gpus != [-1]:
|
||||||
|
return OrigScatter.apply(target_gpus, None, dim, obj)
|
||||||
|
else:
|
||||||
|
# for CPU inference we use self-implemented scatter
|
||||||
|
return Scatter.forward(target_gpus, obj)
|
||||||
if isinstance(obj, DataContainer):
|
if isinstance(obj, DataContainer):
|
||||||
if obj.cpu_only:
|
if obj.cpu_only:
|
||||||
return obj.data
|
return obj.data
|
||||||
|
Loading…
x
Reference in New Issue
Block a user