From b57825f3f1f38cf6033e0d6d36af644cc57465fa Mon Sep 17 00:00:00 2001 From: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com> Date: Tue, 23 Nov 2021 18:35:26 +0800 Subject: [PATCH] [Fix] fix train example (#1502) * [Fix] fix train example * [Fix] fix a detail in train example and add warning in MMDP * [Fix] fix docstring * [Fix] fix docstring --- examples/train.py | 4 +++- mmcv/parallel/data_parallel.py | 10 +++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/train.py b/examples/train.py index d4588f7e7..2dbdfee40 100644 --- a/examples/train.py +++ b/examples/train.py @@ -42,7 +42,9 @@ class Model(nn.Module): if __name__ == '__main__': model = Model() if torch.cuda.is_available(): - model = MMDataParallel(model.cuda()) + # only use gpu:0 to train + # Solved issue https://github.com/open-mmlab/mmcv/issues/1470 + model = MMDataParallel(model.cuda(), device_ids=[0]) # dataset and dataloader transform = transforms.Compose([ diff --git a/mmcv/parallel/data_parallel.py b/mmcv/parallel/data_parallel.py index 79b5f69b6..7a5abeb6e 100644 --- a/mmcv/parallel/data_parallel.py +++ b/mmcv/parallel/data_parallel.py @@ -15,6 +15,14 @@ class MMDataParallel(DataParallel): flexible control of input data during both GPU and CPU inference. - It implement two more APIs ``train_step()`` and ``val_step()``. + .. warning:: + MMDataParallel only supports single GPU training, if you need to + train with multiple GPUs, please use MMDistributedDataParallel + instead. If you have multiple GPUs and you just want to use + MMDataParallel, you can set the environment variable + ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with + ``device_ids=[0]``. + Args: module (:class:`nn.Module`): Module to be encapsulated. device_ids (list[int]): Device IDS of modules to be scattered to. @@ -54,7 +62,7 @@ class MMDataParallel(DataParallel): assert len(self.device_ids) == 1, \ ('MMDataParallel only supports single GPU training, if you need to' ' train with multiple GPUs, please use MMDistributedDataParallel' - 'instead.') + ' instead.') for t in chain(self.module.parameters(), self.module.buffers()): if t.device != self.src_device_obj: