mirror of https://github.com/open-mmlab/mmcv.git
[Fix] fix train example (#1502)
* [Fix] fix train example * [Fix] fix a detail in train example and add warning in MMDP * [Fix] fix docstring * [Fix] fix docstringpull/1463/head^2
parent
e85c43ab87
commit
b57825f3f1
|
@ -42,7 +42,9 @@ class Model(nn.Module):
|
|||
if __name__ == '__main__':
|
||||
model = Model()
|
||||
if torch.cuda.is_available():
|
||||
model = MMDataParallel(model.cuda())
|
||||
# only use gpu:0 to train
|
||||
# Solved issue https://github.com/open-mmlab/mmcv/issues/1470
|
||||
model = MMDataParallel(model.cuda(), device_ids=[0])
|
||||
|
||||
# dataset and dataloader
|
||||
transform = transforms.Compose([
|
||||
|
|
|
@ -15,6 +15,14 @@ class MMDataParallel(DataParallel):
|
|||
flexible control of input data during both GPU and CPU inference.
|
||||
- It implement two more APIs ``train_step()`` and ``val_step()``.
|
||||
|
||||
.. warning::
|
||||
MMDataParallel only supports single GPU training, if you need to
|
||||
train with multiple GPUs, please use MMDistributedDataParallel
|
||||
instead. If you have multiple GPUs and you just want to use
|
||||
MMDataParallel, you can set the environment variable
|
||||
``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
|
||||
``device_ids=[0]``.
|
||||
|
||||
Args:
|
||||
module (:class:`nn.Module`): Module to be encapsulated.
|
||||
device_ids (list[int]): Device IDS of modules to be scattered to.
|
||||
|
@ -54,7 +62,7 @@ class MMDataParallel(DataParallel):
|
|||
assert len(self.device_ids) == 1, \
|
||||
('MMDataParallel only supports single GPU training, if you need to'
|
||||
' train with multiple GPUs, please use MMDistributedDataParallel'
|
||||
'instead.')
|
||||
' instead.')
|
||||
|
||||
for t in chain(self.module.parameters(), self.module.buffers()):
|
||||
if t.device != self.src_device_obj:
|
||||
|
|
Loading…
Reference in New Issue