Fix bug for CPU training (#286)

* remove MMDataParallel when using cpu

* support cpu testing

* fix lint
WRH 2021-06-12 22:26:33 +08:00 committed by GitHub
parent c2f01e0dcd
commit b99bd4fa88
2 changed files with 10 additions and 2 deletions


@@ -87,7 +87,7 @@ def train_model(model,
         model = MMDataParallel(
             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
     elif device == 'cpu':
-        model = MMDataParallel(model.cpu())
+        model = model.cpu()
     else:
         raise ValueError(F'unsupported device name {device}.')
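The CPU branch now returns a plain module instead of wrapping it in MMDataParallel, which expects a CUDA device. Below is a minimal sketch of the resulting dispatch; the opening `if device == 'cuda':` test and the `wrap_for_device` helper name are assumptions for illustration (only the `elif`/`else` part is visible in the hunk), while `MMDataParallel` is the usual wrapper from `mmcv.parallel`.

```python
from mmcv.parallel import MMDataParallel

def wrap_for_device(model, device, cfg):
    """Sketch of the post-patch device dispatch in train_model."""
    if device == 'cuda':
        # Scatter/gather wrapper; needs at least one CUDA device.
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    elif device == 'cpu':
        # Plain nn.Module on CPU; wrapping it in MMDataParallel would
        # try to move data to a GPU and fail on CPU-only machines.
        model = model.cpu()
    else:
        raise ValueError(f'unsupported device name {device}.')
    return model
```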


@ -69,6 +69,11 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--device',
choices=['cpu', 'cuda'],
default='cuda',
help='device used for testing')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
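For reference, this is how the new flag behaves in isolation. The snippet is a standalone argparse sketch using only the option shown in the hunk; the real parser in the test script defines many more arguments.

```python
import argparse

# Standalone sketch of the new --device option.
parser = argparse.ArgumentParser(description='test a model')
parser.add_argument(
    '--device',
    choices=['cpu', 'cuda'],
    default='cuda',
    help='device used for testing')

print(parser.parse_args([]).device)                   # 'cuda' (default)
print(parser.parse_args(['--device', 'cpu']).device)  # 'cpu'
# Any other value, e.g. '--device tpu', exits with an "invalid choice" error.
```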
@@ -122,7 +127,10 @@ def main():
     CLASSES = ImageNet.CLASSES
     if not distributed:
-        model = MMDataParallel(model, device_ids=[0])
+        if args.device == 'cpu':
+            model = model.cpu()
+        else:
+            model = MMDataParallel(model, device_ids=[0])
         model.CLASSES = CLASSES
         show_kwargs = {} if args.show_options is None else args.show_options
         outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
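Taken together, the test script now only wraps the model in MMDataParallel when testing on GPU. A condensed sketch of the non-distributed path, assuming `model`, `data_loader`, `CLASSES`, `args`, and `single_gpu_test` come from the surrounding script (the show/show_dir arguments are omitted here for brevity):

```python
from mmcv.parallel import MMDataParallel

if not distributed:
    if args.device == 'cpu':
        model = model.cpu()                            # run inference on the host CPU
    else:
        model = MMDataParallel(model, device_ids=[0])  # single-GPU wrapper
    model.CLASSES = CLASSES
    outputs = single_gpu_test(model, data_loader)
```

With the flag in place, CPU-only testing is invoked by passing `--device cpu` to the test script (likely `python tools/test.py CONFIG CHECKPOINT --device cpu`; the file path is an assumption, since the file names are not shown in this excerpt), while the default keeps the previous single-GPU behaviour.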