fast-reid/projects/FastTune/tune_net.py

243 lines
8.1 KiB
Python

#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import sys
from functools import partial
import ConfigSpace as CS
import ray
from hyperopt import hp
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune.suggest.hyperopt import HyperOptSearch
sys.path.append('.')
from fastreid.config import get_cfg, CfgNode
from fastreid.engine import hooks
from fastreid.modeling import build_model
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.utils.events import CommonMetricPrinter
from fastreid.utils import comm
from fastreid.utils.file_io import PathManager
from autotuner import *
logger = logging.getLogger("fastreid.auto_tuner")
ray.init(dashboard_host='127.0.0.1')
class AutoTuner(DefaultTrainer):
def build_hooks(self):
r"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = self.cfg.clone()
cfg.defrost()
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(self.optimizer, self.scheduler),
]
ret.append(hooks.LayerFreeze(
self.model,
cfg.MODEL.FREEZE_LAYERS,
cfg.SOLVER.FREEZE_ITERS,
cfg.SOLVER.FREEZE_FC_ITERS,
))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(TuneReportHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter([CommonMetricPrinter(self.max_iter)], 200))
return ret
@classmethod
def build_model(cls, cfg):
model = build_model(cfg)
return model
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
def update_config(cfg, config):
frozen = cfg.is_frozen()
cfg.defrost()
# cfg.SOLVER.BASE_LR = config["lr"]
# cfg.SOLVER.ETA_MIN_LR = config["lr"] * 0.0001
# cfg.SOLVER.DELAY_EPOCHS = int(config["delay_epochs"])
# cfg.MODEL.LOSSES.CE.SCALE = config["ce_scale"]
# cfg.MODEL.HEADS.SCALE = config["circle_scale"]
# cfg.MODEL.HEADS.MARGIN = config["circle_margin"]
# cfg.SOLVER.WEIGHT_DECAY = config["wd"]
# cfg.SOLVER.WEIGHT_DECAY_BIAS = config["wd_bias"]
cfg.SOLVER.IMS_PER_BATCH = config["bsz"]
cfg.DATALOADER.NUM_INSTANCE = config["num_inst"]
if frozen: cfg.freeze()
return cfg
def train_tuner(config, checkpoint_dir=None, cfg=None):
update_config(cfg, config)
tuner = AutoTuner(cfg)
# Load checkpoint if specific
if checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint.pth")
checkpoint = tuner.checkpointer.resume_or_load(path, resume=False)
tuner.start_epoch = checkpoint.get("epoch", -1) + 1
# Regular model training
tuner.train()
def main(args):
cfg = setup(args)
exp_metrics = dict(metric="score", mode="max")
if args.srch_algo == "hyperopt":
# Create a HyperOpt search space
search_space = {
# "lr": hp.loguniform("lr", np.log(1e-6), np.log(1e-3)),
# "delay_epochs": hp.randint("delay_epochs", 20, 60),
# "wd": hp.uniform("wd", 0, 1e-3),
# "wd_bias": hp.uniform("wd_bias", 0, 1e-3),
"bsz": hp.choice("bsz", [64, 96, 128, 160, 224, 256]),
"num_inst": hp.choice("num_inst", [2, 4, 8, 16, 32]),
# "ce_scale": hp.uniform("ce_scale", 0.1, 1.0),
# "circle_scale": hp.choice("circle_scale", [16, 32, 64, 128, 256]),
# "circle_margin": hp.uniform("circle_margin", 0, 1) * 0.4 + 0.1,
}
current_best_params = [{
"bsz": 0, # index of hp.choice list
"num_inst": 3,
}]
search_algo = HyperOptSearch(
search_space,
points_to_evaluate=current_best_params,
**exp_metrics)
if args.pbt:
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
**exp_metrics,
perturbation_interval=2,
hyperparam_mutations={
"bsz": [64, 96, 128, 160, 224, 256],
"num_inst": [2, 4, 8, 16, 32],
}
)
else:
scheduler = ASHAScheduler(
grace_period=2,
reduction_factor=3,
max_t=7,
**exp_metrics)
elif args.srch_algo == "bohb":
search_space = CS.ConfigurationSpace()
search_space.add_hyperparameters([
# CS.UniformFloatHyperparameter(name="lr", lower=1e-6, upper=1e-3, log=True),
# CS.UniformIntegerHyperparameter(name="delay_epochs", lower=20, upper=60),
# CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0),
# CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=128),
# CS.UniformFloatHyperparameter(name="circle_margin", lower=0.1, upper=0.5),
# CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3),
# CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3),
CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]),
CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]),
# CS.CategoricalHyperparameter(name="autoaug_enabled", choices=[True, False]),
# CS.CategoricalHyperparameter(name="cj_enabled", choices=[True, False]),
])
search_algo = TuneBOHB(
search_space, max_concurrent=4, **exp_metrics)
scheduler = HyperBandForBOHB(
time_attr="training_iteration",
reduction_factor=3,
max_t=7,
**exp_metrics,
)
else:
raise ValueError("Search algorithm must be chosen from [hyperopt, bohb], but got {}".format(args.srch_algo))
reporter = CLIReporter(
parameter_columns=["bsz", "num_inst"],
metric_columns=["r1", "map", "training_iteration"])
analysis = tune.run(
partial(
train_tuner,
cfg=cfg),
resources_per_trial={"cpu": 4, "gpu": 1},
search_alg=search_algo,
num_samples=args.num_trials,
scheduler=scheduler,
progress_reporter=reporter,
local_dir=cfg.OUTPUT_DIR,
keep_checkpoints_num=10,
name=args.srch_algo)
best_trial = analysis.get_best_trial("score", "max", "last")
logger.info("Best trial config: {}".format(best_trial.config))
logger.info("Best trial final validation mAP: {}, Rank-1: {}".format(
best_trial.last_result["map"], best_trial.last_result["r1"]))
save_dict = dict(R1=best_trial.last_result["r1"].item(), mAP=best_trial.last_result["map"].item())
save_dict.update(best_trial.config)
path = os.path.join(cfg.OUTPUT_DIR, "best_config.yaml")
with PathManager.open(path, "w") as f:
f.write(CfgNode(save_dict).dump())
logger.info("Best config saved to {}".format(os.path.abspath(path)))
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument("--num-trials", type=int, default=8, help="number of tune trials")
parser.add_argument("--srch-algo", type=str, default="hyperopt",
help="search algorithms for hyperparameters search space")
parser.add_argument("--pbt", action="store_true", help="use population based training")
args = parser.parse_args()
print("Command Line Args:", args)
main(args)