update cpu training and testing

This commit is contained in:
linyiqi 2022-03-10 16:08:56 +08:00
parent 4c677fa1fa
commit da424f6957
5 changed files with 11 additions and 2 deletions

View File

@ -32,7 +32,7 @@ assert (digit_version(mmcv_minimum_version) <= mmcv_version
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
mmdet_minimum_version = '2.16.0' mmdet_minimum_version = '2.16.0'
mmdet_maximum_version = '2.21.0' mmdet_maximum_version = '2.23.0'
mmdet_version = digit_version(mmdet.__version__) mmdet_version = digit_version(mmdet.__version__)

View File

@ -55,6 +55,7 @@ def train_detector(model: nn.Module,
broadcast_buffers=False, broadcast_buffers=False,
find_unused_parameters=find_unused_parameters) find_unused_parameters=find_unused_parameters)
else: else:
# Please use MMCV >= 1.4.4 for CPU training!
model = MMDataParallel(model, device_ids=cfg.gpu_ids) model = MMDataParallel(model, device_ids=cfg.gpu_ids)
# build runner # build runner

View File

@ -70,7 +70,13 @@ def parse_args():
'--gpu-ids', '--gpu-ids',
type=int, type=int,
nargs='+', nargs='+',
help='ids of gpus to use ' help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed testing)') '(only applicable to non-distributed testing)')
parser.add_argument( parser.add_argument(
'--show_task_results', '--show_task_results',

View File

@ -195,6 +195,7 @@ def main():
shuffle=False) shuffle=False)
if not distributed: if not distributed:
# Please use MMCV >= 1.4.4 for CPU testing!
model = MMDataParallel(model, device_ids=cfg.gpu_ids) model = MMDataParallel(model, device_ids=cfg.gpu_ids)
show_kwargs = dict(show_score_thr=args.show_score_thr) show_kwargs = dict(show_score_thr=args.show_score_thr)
if cfg.data.get('model_init', None) is not None: if cfg.data.get('model_init', None) is not None:

View File

@ -142,6 +142,7 @@ def main():
# init distributed env first, since logger depends on the dist info. # init distributed env first, since logger depends on the dist info.
if args.launcher == 'none': if args.launcher == 'none':
distributed = False distributed = False
rank = 0
else: else:
distributed = True distributed = True
init_dist(args.launcher, **cfg.dist_params) init_dist(args.launcher, **cfg.dist_params)