mirror of https://github.com/alibaba/EasyCV.git
parent
265d4cc6e5
commit
acd0619ef9
|
@ -167,12 +167,15 @@ def train_model(model,
|
|||
runner.register_hook(DistSamplerSeedHook())
|
||||
|
||||
# register eval hooks
|
||||
validate = False
|
||||
if 'eval_pipelines' in cfg:
|
||||
if isinstance(cfg.eval_pipelines, dict):
|
||||
cfg.eval_pipelines = [cfg.eval_pipelines]
|
||||
if len(cfg.eval_pipelines) > 0:
|
||||
validate = True
|
||||
if validate:
|
||||
if 'eval_pipelines' not in cfg:
|
||||
runner.logger.warning(
|
||||
'Not find `eval_pipelines` in cfg, skip validation!')
|
||||
validate = False
|
||||
else:
|
||||
if isinstance(cfg.eval_pipelines, dict):
|
||||
cfg.eval_pipelines = [cfg.eval_pipelines]
|
||||
assert len(cfg.eval_pipelines) > 0
|
||||
runner.logger.info('open validate hook')
|
||||
|
||||
best_metric_name = [
|
||||
|
|
|
@ -57,6 +57,10 @@ def parse_args():
|
|||
parser.add_argument('--load_from', help='the checkpoint file to load from')
|
||||
parser.add_argument(
|
||||
'--pretrained', default=None, help='pretrained model file')
|
||||
parser.add_argument(
|
||||
'--no_validate',
|
||||
action='store_true',
|
||||
help='whether not to evaluate the checkpoint during training')
|
||||
parser.add_argument(
|
||||
'--gpus',
|
||||
type=int,
|
||||
|
@ -84,7 +88,6 @@ def parse_args():
|
|||
type=int,
|
||||
default=29500,
|
||||
help='port only works when launcher=="slurm"')
|
||||
|
||||
parser.add_argument(
|
||||
'--model_type',
|
||||
type=str,
|
||||
|
@ -149,10 +152,12 @@ def main():
|
|||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# update configs according to CLI args
|
||||
# if args.work_dir is not None and cfg.get('work_dir', None) is None:
|
||||
if args.work_dir is not None:
|
||||
cfg.work_dir = args.work_dir
|
||||
|
||||
if cfg.get('work_dir', None) is None:
|
||||
cfg.work_dir = './work_dir'
|
||||
|
||||
# if `work_dir` is oss path, redirect `work_dir` to local path, add `oss_work_dir` point to oss path,
|
||||
# and use osssync hook to upload log and ckpt in work_dir to oss_work_dir
|
||||
if cfg.work_dir.startswith('oss://'):
|
||||
|
@ -306,6 +311,7 @@ def main():
|
|||
distributed=distributed,
|
||||
timestamp=timestamp,
|
||||
meta=meta,
|
||||
validate=(not args.no_validate),
|
||||
use_fp16=args.fp16)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue