fix typro in hpo training

pull/365/head
liaoxingyu 2020-10-10 18:55:32 +08:00
parent 8d55a0bbd4
commit e5c3ff3b5b
2 changed files with 5 additions and 4 deletions

View File

@ -90,4 +90,4 @@ DATASETS:
CUDNN_BENCHMARK: True CUDNN_BENCHMARK: True
OUTPUT_DIR: "projects/HPOReID/logs/dukemtmc/r34-ibn_bohb_bsz_num_" OUTPUT_DIR: "projects/HPOReID/logs/dukemtmc/r34-ibn_bohb_bsz_num-inst"

View File

@ -29,7 +29,7 @@ from fastreid.engine import DefaultTrainer, default_argument_parser, default_set
from fastreid.utils.events import CommonMetricPrinter from fastreid.utils.events import CommonMetricPrinter
from fastreid.utils.file_io import PathManager from fastreid.utils.file_io import PathManager
from naic.tune_hooks import TuneReportHook from hporeid import *
logger = logging.getLogger("fastreid.project.tune") logger = logging.getLogger("fastreid.project.tune")
@ -180,7 +180,8 @@ def main(args):
# CS.UniformIntegerHyperparameter(name="delay_iters", lower=20, upper=60), # CS.UniformIntegerHyperparameter(name="delay_iters", lower=20, upper=60),
# CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3), # CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3),
# CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3), # CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3),
CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]), # CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]),
CS.CategoricalHyperparameter(name="bsz", choices=[32, 64]),
CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]), CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]),
# CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0), # CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0),
# CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=256), # CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=256),
@ -212,7 +213,7 @@ def main(args):
cfg=cfg), cfg=cfg),
resources_per_trial={"cpu": 12, "gpu": 1}, resources_per_trial={"cpu": 12, "gpu": 1},
search_alg=search_algo, search_alg=search_algo,
num_samples=args.num_samples, num_samples=args.num_trials,
scheduler=scheduler, scheduler=scheduler,
progress_reporter=reporter, progress_reporter=reporter,
local_dir=cfg.OUTPUT_DIR, local_dir=cfg.OUTPUT_DIR,