This commit is contained in:
xieenze 2020-09-09 16:56:55 +08:00
parent 4cd6eb6cba
commit d2284b373f

View File

@ -248,9 +248,12 @@ def _non_dist_train(model,
seed=cfg.seed,
drop_last=getattr(cfg.data, 'drop_last', False)) for ds in dataset
]
if 'use_fp16' in cfg and cfg.use_fp16 == True:
raise NotImplementedError('apex do not support non_dist_trian!')
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = Runner(