mirror of https://github.com/JDAI-CV/fast-reid.git

add more hpo search algorithms

Summary: support hyperopt, bohb and population based training

parent bd395917a8
commit 8d55a0bbd4

projects/HPOReID
@@ -6,15 +6,17 @@ Install the following
 ```bash
 pip install 'ray[tune]'
-pip install hpbandster ConfigSpace
+pip install hpbandster ConfigSpace hyperopt
 ```
 
-## Training
+## Example
 
-To train a model with `BOHB`, run
+This is an example of tuning `batch_size` and `num_instance` automatically.
+
+To run hyperparameter optimization with the BOHB (Bayesian Optimization with HyperBand) search algorithm, run
 
 ```bash
-python3 projects/HPOReID/train_hpo.py --config-file projects/HPOReID/configs/baseline.yml
+python3 projects/HPOReID/train_hpo.py --config-file projects/HPOReID/configs/baseline.yml --srch-algo "bohb"
 ```
 
 ## Known issues
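The `--srch-algo` flag introduced by this commit selects both the Tune search algorithm and the trial scheduler. A minimal sketch of that pairing, assuming the Ray 1.x API used throughout this project (`build_search` is a hypothetical helper, not part of the commit):

```python
from ray.tune.schedulers import ASHAScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune.suggest.hyperopt import HyperOptSearch


def build_search(algo, search_space, metric="score", mode="max"):
    # "hyperopt" pairs HyperOptSearch with the ASHA scheduler;
    # "bohb" pairs TuneBOHB with HyperBandForBOHB, mirroring train_hpo.py below
    if algo == "hyperopt":
        return (HyperOptSearch(search_space, metric=metric, mode=mode),
                ASHAScheduler(metric=metric, mode=mode, max_t=10))
    if algo == "bohb":
        return (TuneBOHB(search_space, metric=metric, mode=mode),
                HyperBandForBOHB(time_attr="training_iteration",
                                 metric=metric, mode=mode, max_t=9))
    raise ValueError("unknown search algorithm: {}".format(algo))
```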
@@ -90,4 +90,4 @@ DATASETS:
 CUDNN_BENCHMARK: True
 
-OUTPUT_DIR: "projects/HPOReID/logs/dukemtmc/r34-ibn_bohb"
+OUTPUT_DIR: "projects/HPOReID/logs/dukemtmc/r34-ibn_bohb_bsz_num_"
@@ -38,10 +38,12 @@ class TuneReportHook(EvalHook):
         torch.cuda.empty_cache()
 
+        self.step += 1
+
         # Here we save a checkpoint. It is automatically registered with
-        # Ray Tuen and will potentially be passed as the `checkpoint_dir`
+        # Ray Tune and will potentially be passed as the `checkpoint_dir`
         # parameter in future iterations.
         with tune.checkpoint_dir(step=self.step) as checkpoint_dir:
+            additional_state = {"iteration": int(self.trainer.iter)}
             Checkpointer(
                 # Assume you want to save checkpoints together with logs/statistics
                 self.trainer.model,
@@ -49,6 +51,7 @@ class TuneReportHook(EvalHook):
                 save_to_disk=True,
                 optimizer=self.trainer.optimizer,
                 scheduler=self.trainer.scheduler,
-            ).save(name="checkpoint")
+            ).save(name="checkpoint", **additional_state)
 
-        tune.report(r1=results['Rank-1'], map=results['mAP'], score=(results['Rank-1'] + results['mAP']) / 2)
+        metrics = dict(r1=results['Rank-1'], map=results['mAP'], score=(results['Rank-1'] + results['mAP']) / 2)
+        tune.report(**metrics)
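For context, the save-then-report pattern this hook follows can be sketched as a bare Tune function trainable. This is a minimal sketch with placeholder metrics and a hypothetical checkpoint file name, assuming the Ray 1.x function API:

```python
import os

import torch
from ray import tune


def trainable(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:  # Tune passes this back when a trial is restored
        state = torch.load(os.path.join(checkpoint_dir, "checkpoint.pth"))
        start = state["iteration"] + 1
    for step in range(start, 5):
        results = {"Rank-1": 0.90, "mAP": 0.80}  # placeholder evaluation results
        # checkpoint first, so every reported step has a matching snapshot
        with tune.checkpoint_dir(step=step) as ckpt_dir:
            torch.save({"iteration": step}, os.path.join(ckpt_dir, "checkpoint.pth"))
        tune.report(r1=results["Rank-1"], map=results["mAP"],
                    score=(results["Rank-1"] + results["mAP"]) / 2)
```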
@@ -12,20 +12,24 @@ from functools import partial
 
 import ConfigSpace as CS
 import ray
+from hyperopt import hp
 from ray import tune
 from ray.tune import CLIReporter
 from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
 from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
 from ray.tune.suggest.bohb import TuneBOHB
+from ray.tune.suggest.hyperopt import HyperOptSearch
 
 sys.path.append('.')
 
-from fastreid.config import get_cfg
+from fastreid.config import get_cfg, CfgNode
 from fastreid.engine import hooks
 from fastreid.modeling import build_model
 from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
 from fastreid.utils.events import CommonMetricPrinter
+from fastreid.utils.file_io import PathManager
 
 from hporeid import *
+from naic.tune_hooks import TuneReportHook
 
 logger = logging.getLogger("fastreid.project.tune")
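A side note on these import paths: they follow the Ray 1.x layout this project targets. In Ray 2.0 and later the suggesters moved out of `ray.tune.suggest`, so the rough equivalents would be (an assumption about newer Ray, not part of this commit):

```python
# Ray >= 2.0 locations of the suggesters imported above
from ray.tune.search.bohb import TuneBOHB
from ray.tune.search.hyperopt import HyperOptSearch
```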
@@ -89,30 +93,40 @@ def setup(args):
     return cfg
 
 
-def train_reid_tune(cfg, config, checkpoint_dir=None):
+def update_config(cfg, config):
     cfg.defrost()
 
-    # Hyperparameter tuning
-    cfg.SOLVER.BASE_LR = config["lr"]
-    cfg.SOLVER.WEIGHT_DECAY = config["wd"]
-    cfg.SOLVER.WEIGHT_DECAY_BIAS = config["wd_bias"]
+    # lr, weight decay
+    # cfg.SOLVER.BASE_LR = config["lr"]
+    # cfg.SOLVER.ETA_MIN_LR = config["lr"] * 0.0022
+    # cfg.SOLVER.DELAY_ITERS = config["delay_iters"]
+    # cfg.SOLVER.WEIGHT_DECAY = config["wd"]
+    # cfg.SOLVER.WEIGHT_DECAY_BIAS = config["wd_bias"]
 
+    # batch size, number of instance
     cfg.SOLVER.IMS_PER_BATCH = config["bsz"]
     cfg.DATALOADER.NUM_INSTANCE = config["num_inst"]
-    cfg.SOLVER.DELAY_ITERS = config["delay_iters"]
-    cfg.SOLVER.ETA_MIN_LR = config["lr"] * 0.0022
-    cfg.MODEL.LOSSES.CE.SCALE = config["ce_scale"]
-    cfg.MODEL.HEADS.SCALE = config["circle_scale"]
-    cfg.MODEL.HEADS.MARGIN = config["circle_margin"]
-    cfg.INPUT.DO_AUTOAUG = config["autoaug_enabled"]
-    cfg.INPUT.CJ.ENABLED = config["cj_enabled"]
 
+    # loss related
+    # cfg.MODEL.LOSSES.CE.SCALE = config["ce_scale"]
+    # cfg.MODEL.HEADS.SCALE = config["circle_scale"]
+    # cfg.MODEL.HEADS.MARGIN = config["circle_margin"]
 
+    # data augmentation
+    # cfg.INPUT.DO_AUTOAUG = config["autoaug_enabled"]
+    # cfg.INPUT.CJ.ENABLED = config["cj_enabled"]
+    return cfg
+
+
+def train_reid_tune(config, checkpoint_dir=None, cfg=None):
+    update_config(cfg, config)
 
     trainer = HyperTuneTrainer(cfg)
 
     # Load checkpoint if specified
     if checkpoint_dir:
-        path = os.path.join(checkpoint_dir, "checkpoint")
-        checkpoint = trainer.checkpointer.resume_or_load(path)
-        if checkpoint.checkpointer.has_checkpoint():
-            trainer.start_iter = checkpoint.get("iteration", -1) + 1
+        path = os.path.join(checkpoint_dir, "checkpoint.pth")
+        checkpoint = trainer.checkpointer.resume_or_load(path, resume=False)
+        trainer.start_iter = checkpoint.get("iteration", -1) + 1
 
     # Regular model training
     trainer.train()
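A quick usage sketch of the refactored `update_config`: Tune samples a plain dict, and the function writes its two tuned keys into the frozen config. The trial values below are hypothetical:

```python
# hypothetical sampled trial containing the two keys update_config() reads
trial_config = {"bsz": 128, "num_inst": 8}

cfg = setup(args)  # frozen base config loaded from the yaml file
cfg = update_config(cfg, trial_config)
assert cfg.SOLVER.IMS_PER_BATCH == 128
assert cfg.DATALOADER.NUM_INSTANCE == 8
```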
@@ -121,57 +135,109 @@ def train_reid_tune(cfg, config, checkpoint_dir=None):
 def main(args):
     cfg = setup(args)
 
-    search_space = CS.ConfigurationSpace()
-    search_space.add_hyperparameters([
-        CS.UniformFloatHyperparameter(name="lr", lower=1e-6, upper=1e-3),
-        CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3),
-        CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3),
-        CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]),
-        CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]),
-        CS.UniformIntegerHyperparameter(name="delay_iters", lower=20, upper=60),
-        CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0),
-        CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=256),
-        CS.UniformFloatHyperparameter(name="circle_margin", lower=0.1, upper=0.5),
-        CS.CategoricalHyperparameter(name="autoaug_enabled", choices=[True, False]),
-        CS.CategoricalHyperparameter(name="cj_enabled", choices=[True, False]),
-    ]
-    )
 
     exp_metrics = dict(metric="score", mode="max")
-    bohb_hyperband = HyperBandForBOHB(
-        time_attr="training_iteration",
-        max_t=7,
-        **exp_metrics,
-    )
-    bohb_search = TuneBOHB(
-        search_space, max_concurrent=4, **exp_metrics)
 
+    if args.srch_algo == "hyperopt":
+        # Create a HyperOpt search space
+        search_space = {
+            # "lr": hp.uniform("lr", 1e-6, 1e-3),
+            # "delay_iters": hp.randint("delay_iters", 40) + 10,
+            # "wd": hp.uniform("wd", 0, 1e-3),
+            # "wd_bias": hp.uniform("wd_bias", 0, 1e-3),
+            "bsz": hp.choice("bsz", [64, 96, 128, 160, 224, 256]),
+            "num_inst": hp.choice("num_inst", [2, 4, 8, 16, 32]),
+            # "ce_scale": hp.uniform("ce_scale", 0.1, 1.0),
+            # "circle_scale": hp.choice("circle_scale", [16, 32, 64, 128, 256]),
+            # "circle_margin": hp.uniform("circle_margin", 0, 1) * 0.4 + 0.1,
+            # "autoaug_enabled": hp.choice("autoaug_enabled", [True, False]),
+            # "cj_enabled": hp.choice("cj_enabled", [True, False]),
+        }
+
+        search_algo = HyperOptSearch(search_space, **exp_metrics)
+
+        if args.pbt:
+            scheduler = PopulationBasedTraining(
+                time_attr="training_iteration",
+                **exp_metrics,
+                perturbation_interval=2,
+                hyperparam_mutations={
+                    "bsz": [64, 96, 128, 160, 224, 256],
+                    "num_inst": [2, 4, 8, 16, 32],
+                }
+            )
+        else:
+            scheduler = ASHAScheduler(
+                metric="score",
+                mode="max",
+                max_t=10,
+                grace_period=1,
+                reduction_factor=2)
+
+    elif args.srch_algo == "bohb":
+        search_space = CS.ConfigurationSpace()
+        search_space.add_hyperparameters([
+            # CS.UniformFloatHyperparameter(name="lr", lower=1e-6, upper=1e-3),
+            # CS.UniformIntegerHyperparameter(name="delay_iters", lower=20, upper=60),
+            # CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3),
+            # CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3),
+            CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]),
+            CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]),
+            # CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0),
+            # CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=256),
+            # CS.UniformFloatHyperparameter(name="circle_margin", lower=0.1, upper=0.5),
+            # CS.CategoricalHyperparameter(name="autoaug_enabled", choices=[True, False]),
+            # CS.CategoricalHyperparameter(name="cj_enabled", choices=[True, False]),
+        ])
+
+        search_algo = TuneBOHB(
+            search_space, max_concurrent=4, **exp_metrics)
+
+        scheduler = HyperBandForBOHB(
+            time_attr="training_iteration",
+            reduction_factor=2,
+            max_t=9,
+            **exp_metrics,
+        )
+
+    else:
+        raise ValueError("Search algorithm must be chosen from [hyperopt, bohb], but got {}".format(args.srch_algo))
 
     reporter = CLIReporter(
-        parameter_columns=["bsz", "num_inst", "lr"],
+        parameter_columns=["bsz", "num_inst"],
         metric_columns=["r1", "map", "training_iteration"])
 
     analysis = tune.run(
         partial(
             train_reid_tune,
-            cfg),
-        resources_per_trial={"cpu": 10, "gpu": 1},
-        search_alg=bohb_search,
+            cfg=cfg),
+        resources_per_trial={"cpu": 12, "gpu": 1},
+        search_alg=search_algo,
-        num_samples=args.num_samples,
+        num_samples=args.num_trials,
-        scheduler=bohb_hyperband,
+        scheduler=scheduler,
         progress_reporter=reporter,
         local_dir=cfg.OUTPUT_DIR,
-        keep_checkpoints_num=4,
-        name="bohb")
+        keep_checkpoints_num=10,
+        name=args.srch_algo)
 
-    best_trial = analysis.get_best_trial("map", "max", "last")
+    best_trial = analysis.get_best_trial("score", "max", "last")
     logger.info("Best trial config: {}".format(best_trial.config))
     logger.info("Best trial final validation mAP: {}, Rank-1: {}".format(
         best_trial.last_result["map"], best_trial.last_result["r1"]))
 
+    save_dict = dict(R1=best_trial.last_result["r1"], mAP=best_trial.last_result["map"])
+    save_dict.update(best_trial.config)
+    path = os.path.join(cfg.OUTPUT_DIR, "best_config.yaml")
+    with PathManager.open(path, "w") as f:
+        f.write(CfgNode(save_dict).dump())
+    logger.info("Best config saved to {}".format(os.path.abspath(path)))
 
 
 if __name__ == "__main__":
     parser = default_argument_parser()
-    parser.add_argument("--num-samples", type=int, default=20, help="number of tune trials")
+    parser.add_argument("--num-trials", type=int, default=12, help="number of tune trials")
+    parser.add_argument("--srch-algo", type=str, default="bohb",
+                        help="search algorithms for hyperparameters search space")
+    parser.add_argument("--pbt", action="store_true", help="use population based training")
     args = parser.parse_args()
     print("Command Line Args:", args)
     main(args)
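Finally, a small sketch of consuming the `best_config.yaml` that `main()` now exports. The path matches the `OUTPUT_DIR` set in baseline.yml above; `load_cfg` is the standard yacs reader that fastreid's `CfgNode` inherits, assumed here:

```python
from fastreid.config import CfgNode
from fastreid.utils.file_io import PathManager

# read back the best-trial summary written at the end of main()
path = "projects/HPOReID/logs/dukemtmc/r34-ibn_bohb_bsz_num_/best_config.yaml"
with PathManager.open(path, "r") as f:
    best = CfgNode.load_cfg(f)

print(best.R1, best.mAP)        # metrics of the winning trial
print(best.bsz, best.num_inst)  # its sampled hyperparameters
```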