update cpu training and testing

This commit is contained in:
linyiqi 2022-03-10 16:08:56 +08:00
parent 4c677fa1fa
commit da424f6957
5 changed files with 11 additions and 2 deletions

View File

@ -32,7 +32,7 @@ assert (digit_version(mmcv_minimum_version) <= mmcv_version
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
mmdet_minimum_version = '2.16.0'
mmdet_maximum_version = '2.21.0'
mmdet_maximum_version = '2.23.0'
mmdet_version = digit_version(mmdet.__version__)

View File

@ -55,6 +55,7 @@ def train_detector(model: nn.Module,
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
# Please use MMCV >= 1.4.4 for CPU training!
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
# build runner

View File

@ -70,7 +70,13 @@ def parse_args():
'--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--show_task_results',

View File

@ -195,6 +195,7 @@ def main():
shuffle=False)
if not distributed:
# Please use MMCV >= 1.4.4 for CPU testing!
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
show_kwargs = dict(show_score_thr=args.show_score_thr)
if cfg.data.get('model_init', None) is not None:

View File

@ -142,6 +142,7 @@ def main():
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
rank = 0
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)