diff --git a/tools/train.py b/tools/train.py index ac9fe6ed..622041a0 100755 --- a/tools/train.py +++ b/tools/train.py @@ -90,15 +90,15 @@ def main(): if args.resume: cfg.resume = True + # enable automatically scaling LR if args.auto_scale_lr: - if cfg.get('auto_scale_lr'): - cfg.auto_scale_lr = True + if 'auto_scale_lr' in cfg and \ + 'base_batch_size' in cfg.auto_scale_lr: + cfg.auto_scale_lr.enable = True else: - print_log( - 'auto_scale_lr does not exist in your config, ' - 'please set `auto_scale_lr = dict(base_batch_size=xx)', - logger='current', - level=logging.WARNING) + raise RuntimeError('Can not find "auto_scale_lr" or ' + '"auto_scale_lr.base_batch_size" in your' + ' configuration file.') # build the runner from config if 'runner_type' not in cfg: