From b99bd4fa88c5822583ef131ba010269aba8102a7 Mon Sep 17 00:00:00 2001
From: WRH <12756472+wangruohui@users.noreply.github.com>
Date: Sat, 12 Jun 2021 22:26:33 +0800
Subject: [PATCH] Fix bug for CPU training (#286)

* remove MMDataParallel when using cpu

* support cpu testing

* fix lint
---
 mmcls/apis/train.py |  2 +-
 tools/test.py       | 10 +++++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py
index 94c97897..42b9a5b4 100644
--- a/mmcls/apis/train.py
+++ b/mmcls/apis/train.py
@@ -87,7 +87,7 @@ def train_model(model,
             model = MMDataParallel(
                 model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
         elif device == 'cpu':
-            model = MMDataParallel(model.cpu())
+            model = model.cpu()
         else:
             raise ValueError(F'unsupported device name {device}.')
 
diff --git a/tools/test.py b/tools/test.py
index e5adadc3..0676ad24 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -69,6 +69,11 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--device',
+        choices=['cpu', 'cuda'],
+        default='cuda',
+        help='device used for testing')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -122,7 +127,10 @@ def main():
         CLASSES = ImageNet.CLASSES
 
     if not distributed:
-        model = MMDataParallel(model, device_ids=[0])
+        if args.device == 'cpu':
+            model = model.cpu()
+        else:
+            model = MMDataParallel(model, device_ids=[0])
         model.CLASSES = CLASSES
         show_kwargs = {} if args.show_options is None else args.show_options
         outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,