[Fix] fix train example (#1502)

* [Fix] fix train example

* [Fix] fix a detail in train example and add warning in MMDP

* [Fix] fix docstring

* [Fix] fix docstring
Jiazhen Wang 2021-11-23 18:35:26 +08:00 committed by GitHub
parent e85c43ab87
commit b57825f3f1
2 changed files with 12 additions and 2 deletions


@@ -42,7 +42,9 @@ class Model(nn.Module):
 
 if __name__ == '__main__':
     model = Model()
     if torch.cuda.is_available():
-        model = MMDataParallel(model.cuda())
+        # only use gpu:0 to train
+        # Solved issue https://github.com/open-mmlab/mmcv/issues/1470
+        model = MMDataParallel(model.cuda(), device_ids=[0])
     # dataset and dataloader
     transform = transforms.Compose([

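A minimal sketch of the fixed usage, assuming the Model class defined earlier in the example and a machine where more than one GPU may be visible:

    import torch
    from mmcv.parallel import MMDataParallel

    model = Model()  # the toy model defined above in the example
    if torch.cuda.is_available():
        # Pin training to gpu:0. Without device_ids, DataParallel defaults
        # to every visible GPU, and MMDataParallel's single-GPU assert in
        # train_step() then fails on multi-GPU machines (issue #1470).
        model = MMDataParallel(model.cuda(), device_ids=[0])
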

@@ -15,6 +15,14 @@ class MMDataParallel(DataParallel):
       flexible control of input data during both GPU and CPU inference.
     - It implements two more APIs: ``train_step()`` and ``val_step()``.
 
+    .. warning::
+        MMDataParallel only supports single GPU training; if you need to
+        train with multiple GPUs, please use MMDistributedDataParallel
+        instead. If you have multiple GPUs and you just want to use
+        MMDataParallel, you can set the environment variable
+        ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
+        ``device_ids=[0]``.
+
     Args:
         module (:class:`nn.Module`): Module to be encapsulated.
         device_ids (list[int]): Device IDs of modules to be scattered to.
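
The warning names two equivalent fixes. The device_ids=[0] route appears in the example diff above; below is a sketch of the environment-variable route, which must run before CUDA is initialised:

    import os
    # Hide all but gpu:0 before torch initialises CUDA (equivalently, set
    # CUDA_VISIBLE_DEVICES=0 in the shell before launching the script).
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    import torch
    from mmcv.parallel import MMDataParallel

    # Model is the example class from the first file; with only one GPU
    # visible, the default device_ids already contains a single device.
    model = MMDataParallel(Model().cuda())
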
@@ -54,7 +62,7 @@ class MMDataParallel(DataParallel):
         assert len(self.device_ids) == 1, \
             ('MMDataParallel only supports single GPU training, if you need to'
              ' train with multiple GPUs, please use MMDistributedDataParallel'
-             'instead.')
+             ' instead.')
 
         for t in chain(self.module.parameters(), self.module.buffers()):
             if t.device != self.src_device_obj:
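
The one-character fix above matters because Python joins adjacent string literals with no separator, so the old assertion message ran two words together:

    old = ('please use MMDistributedDataParallel'
           'instead.')   # -> 'please use MMDistributedDataParallelinstead.'
    new = ('please use MMDistributedDataParallel'
           ' instead.')  # -> 'please use MMDistributedDataParallel instead.'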