Fix bug for CPU training (#286)
* remove MMDataParallel when using cpu * support cpu testing * fix lintpull/304/head
parent
c2f01e0dcd
commit
b99bd4fa88
|
@ -87,7 +87,7 @@ def train_model(model,
|
|||
model = MMDataParallel(
|
||||
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
|
||||
elif device == 'cpu':
|
||||
model = MMDataParallel(model.cpu())
|
||||
model = model.cpu()
|
||||
else:
|
||||
raise ValueError(F'unsupported device name {device}.')
|
||||
|
||||
|
|
|
@ -69,6 +69,11 @@ def parse_args():
|
|||
default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--device',
|
||||
choices=['cpu', 'cuda'],
|
||||
default='cuda',
|
||||
help='device used for testing')
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
@ -122,6 +127,9 @@ def main():
|
|||
CLASSES = ImageNet.CLASSES
|
||||
|
||||
if not distributed:
|
||||
if args.device == 'cpu':
|
||||
model = model.cpu()
|
||||
else:
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
model.CLASSES = CLASSES
|
||||
show_kwargs = {} if args.show_options is None else args.show_options
|
||||
|
|
Loading…
Reference in New Issue