add no_validation arg (#280)

* add no_validation arg
pull/283/head
Cathy0908 2023-02-10 10:18:50 +08:00 committed by GitHub
parent 265d4cc6e5
commit acd0619ef9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 8 deletions

View File

@ -167,12 +167,15 @@ def train_model(model,
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
validate = False
if 'eval_pipelines' in cfg:
if isinstance(cfg.eval_pipelines, dict):
cfg.eval_pipelines = [cfg.eval_pipelines]
if len(cfg.eval_pipelines) > 0:
validate = True
if validate:
if 'eval_pipelines' not in cfg:
runner.logger.warning(
'Not find `eval_pipelines` in cfg, skip validation!')
validate = False
else:
if isinstance(cfg.eval_pipelines, dict):
cfg.eval_pipelines = [cfg.eval_pipelines]
assert len(cfg.eval_pipelines) > 0
runner.logger.info('open validate hook')
best_metric_name = [

View File

@ -57,6 +57,10 @@ def parse_args():
parser.add_argument('--load_from', help='the checkpoint file to load from')
parser.add_argument(
'--pretrained', default=None, help='pretrained model file')
parser.add_argument(
'--no_validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
parser.add_argument(
'--gpus',
type=int,
@ -84,7 +88,6 @@ def parse_args():
type=int,
default=29500,
help='port only works when launcher=="slurm"')
parser.add_argument(
'--model_type',
type=str,
@ -149,10 +152,12 @@ def main():
torch.backends.cudnn.benchmark = True
# update configs according to CLI args
# if args.work_dir is not None and cfg.get('work_dir', None) is None:
if args.work_dir is not None:
cfg.work_dir = args.work_dir
if cfg.get('work_dir', None) is None:
cfg.work_dir = './work_dir'
# if `work_dir` is oss path, redirect `work_dir` to local path, add `oss_work_dir` point to oss path,
# and use osssync hook to upload log and ckpt in work_dir to oss_work_dir
if cfg.work_dir.startswith('oss://'):
@ -306,6 +311,7 @@ def main():
distributed=distributed,
timestamp=timestamp,
meta=meta,
validate=(not args.no_validate),
use_fp16=args.fp16)