diff --git a/easycv/apis/train.py b/easycv/apis/train.py index 1c19c927..cceb635f 100644 --- a/easycv/apis/train.py +++ b/easycv/apis/train.py @@ -167,12 +167,15 @@ def train_model(model, runner.register_hook(DistSamplerSeedHook()) # register eval hooks - validate = False - if 'eval_pipelines' in cfg: - if isinstance(cfg.eval_pipelines, dict): - cfg.eval_pipelines = [cfg.eval_pipelines] - if len(cfg.eval_pipelines) > 0: - validate = True + if validate: + if 'eval_pipelines' not in cfg: + runner.logger.warning( + 'Not find `eval_pipelines` in cfg, skip validation!') + validate = False + else: + if isinstance(cfg.eval_pipelines, dict): + cfg.eval_pipelines = [cfg.eval_pipelines] + assert len(cfg.eval_pipelines) > 0 runner.logger.info('open validate hook') best_metric_name = [ diff --git a/tools/train.py b/tools/train.py index 55d0b438..8ba5ce59 100644 --- a/tools/train.py +++ b/tools/train.py @@ -57,6 +57,10 @@ def parse_args(): parser.add_argument('--load_from', help='the checkpoint file to load from') parser.add_argument( '--pretrained', default=None, help='pretrained model file') + parser.add_argument( + '--no_validate', + action='store_true', + help='whether not to evaluate the checkpoint during training') parser.add_argument( '--gpus', type=int, @@ -84,7 +88,6 @@ def parse_args(): type=int, default=29500, help='port only works when launcher=="slurm"') - parser.add_argument( '--model_type', type=str, @@ -149,10 +152,12 @@ def main(): torch.backends.cudnn.benchmark = True # update configs according to CLI args - # if args.work_dir is not None and cfg.get('work_dir', None) is None: if args.work_dir is not None: cfg.work_dir = args.work_dir + if cfg.get('work_dir', None) is None: + cfg.work_dir = './work_dir' + # if `work_dir` is oss path, redirect `work_dir` to local path, add `oss_work_dir` point to oss path, # and use osssync hook to upload log and ckpt in work_dir to oss_work_dir if cfg.work_dir.startswith('oss://'): @@ -306,6 +311,7 @@ def main(): distributed=distributed, timestamp=timestamp, meta=meta, + validate=(not args.no_validate), use_fp16=args.fp16)