From b99bd4fa88c5822583ef131ba010269aba8102a7 Mon Sep 17 00:00:00 2001 From: WRH <12756472+wangruohui@users.noreply.github.com> Date: Sat, 12 Jun 2021 22:26:33 +0800 Subject: [PATCH] Fix bug for CPU training (#286) * remove MMDataParallel when using cpu * support cpu testing * fix lint --- mmcls/apis/train.py | 2 +- tools/test.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py index 94c97897..42b9a5b4 100644 --- a/mmcls/apis/train.py +++ b/mmcls/apis/train.py @@ -87,7 +87,7 @@ def train_model(model, model = MMDataParallel( model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) elif device == 'cpu': - model = MMDataParallel(model.cpu()) + model = model.cpu() else: raise ValueError(F'unsupported device name {device}.') diff --git a/tools/test.py b/tools/test.py index e5adadc3..0676ad24 100644 --- a/tools/test.py +++ b/tools/test.py @@ -69,6 +69,11 @@ def parse_args(): default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--device', + choices=['cpu', 'cuda'], + default='cuda', + help='device used for testing') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) @@ -122,7 +127,10 @@ def main(): CLASSES = ImageNet.CLASSES if not distributed: - model = MMDataParallel(model, device_ids=[0]) + if args.device == 'cpu': + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=[0]) model.CLASSES = CLASSES show_kwargs = {} if args.show_options is None else args.show_options outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,