Fix bug in gpu_ids in distributed training (#107)
* update gpu_ids in distributed training
* move linear scaling rule after getting correct gpu_ids
* Remove support for autoscale_lr
parent b1e91f256b
commit f355f15485
@@ -7,7 +7,7 @@ import time
 import mmcv
 import torch
 from mmcv import Config, DictAction
-from mmcv.runner import init_dist
+from mmcv.runner import get_dist_info, init_dist
 
 from mmcls import __version__
 from mmcls.apis import set_random_seed, train_model
@@ -51,10 +51,6 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
-    parser.add_argument(
-        '--autoscale-lr',
-        action='store_true',
-        help='automatically scale lr with the number of gpus')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -87,16 +83,14 @@ def main():
     else:
         cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
 
-    if args.autoscale_lr:
-        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
-        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
-
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False
     else:
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
 
     # create work_dir
     mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
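For context, a minimal sketch of the control flow this diff leaves behind, assuming the usual mmcv-style training entry point around it; `args` and `cfg` stand in for the parsed CLI arguments and the loaded Config, and `resolve_device_setup` is a hypothetical helper, not part of the repository:

from mmcv.runner import get_dist_info, init_dist


def resolve_device_setup(args, cfg):
    """Sketch of the post-fix logic for picking cfg.gpu_ids."""
    # non-distributed default: trust --gpus, falling back to a single GPU
    cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # once the process group is initialized, the world size reported by
        # get_dist_info() is authoritative, so gpu_ids no longer depends on
        # the (possibly unset or wrong) --gpus argument
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
    return distributed

The removed --autoscale-lr branch applied the linear scaling rule (https://arxiv.org/abs/1706.02677), cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8; with the corrected gpu_ids, a config lr of 0.1 tuned for 8 GPUs would become 0.05 when the world size is 4. Per the commit message, the scaling was first moved after gpu_ids is resolved and the flag was then dropped entirely.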