#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import logging
import os
import sys
from functools import partial

import ConfigSpace as CS
import ray
from hyperopt import hp
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune.suggest.hyperopt import HyperOptSearch

sys.path.append('.')

from fastreid.config import get_cfg, CfgNode
from fastreid.engine import hooks
from fastreid.modeling import build_model
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.utils.events import CommonMetricPrinter
from fastreid.utils import comm
from fastreid.utils.file_io import PathManager

from autotuner import *
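# `TuneReportHook` used in AutoTuner.build_hooks below is assumed to be
# provided by the `autotuner` package imported above. A minimal sketch of what
# such a hook could look like (an assumption, not the actual implementation):
# an EvalHook that forwards each evaluation result to Ray Tune, e.g.
#
#   class TuneReportHook(hooks.EvalHook):
#       def _do_eval(self):
#           results = self._func()
#           # report re-id metrics under the names consumed in main(),
#           # i.e. "r1", "map" and the aggregate "score"
#           tune.report(r1=results["Rank-1"], map=results["mAP"],
#                       score=(results["Rank-1"] + results["mAP"]) / 2)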

logger = logging.getLogger("fastreid.auto_tuner")

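# Ray is started at import time here; the dashboard is bound to localhost only.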
ray.init(dashboard_host='127.0.0.1')


class AutoTuner(DefaultTrainer):
    def build_hooks(self):
        r"""
        Build the list of hooks used during tuning, including timing,
        lr scheduling, layer freezing, Tune metric reporting and event writing.
        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        ret.append(hooks.LayerFreeze(
            self.model,
            cfg.MODEL.FREEZE_LAYERS,
            cfg.SOLVER.FREEZE_ITERS,
            cfg.SOLVER.FREEZE_FC_ITERS,
        ))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Report evaluation results back to Ray Tune (TuneReportHook comes
        # from the autotuner package imported above).
        ret.append(TuneReportHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter([CommonMetricPrinter(self.max_iter)], 200))

        return ret

    @classmethod
    def build_model(cls, cfg):
        model = build_model(cfg)
        return model


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def update_config(cfg, config):
    frozen = cfg.is_frozen()
    cfg.defrost()

    # cfg.SOLVER.BASE_LR = config["lr"]
    # cfg.SOLVER.ETA_MIN_LR = config["lr"] * 0.0001
    # cfg.SOLVER.DELAY_EPOCHS = int(config["delay_epochs"])
    # cfg.MODEL.LOSSES.CE.SCALE = config["ce_scale"]
    # cfg.MODEL.HEADS.SCALE = config["circle_scale"]
    # cfg.MODEL.HEADS.MARGIN = config["circle_margin"]
    # cfg.SOLVER.WEIGHT_DECAY = config["wd"]
    # cfg.SOLVER.WEIGHT_DECAY_BIAS = config["wd_bias"]
    cfg.SOLVER.IMS_PER_BATCH = config["bsz"]
    cfg.DATALOADER.NUM_INSTANCE = config["num_inst"]
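    # NOTE (assumption): fastreid's identity sampler generally expects
    # IMS_PER_BATCH to be divisible by NUM_INSTANCE; the search spaces in
    # main() do not enforce this constraint on sampled pairs.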

    if frozen: cfg.freeze()

    return cfg


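# Per-trial entry point passed to tune.run: `config` holds the hyperparameters
# sampled by the search algorithm, and `checkpoint_dir` is supplied when Tune
# restores a trial from a checkpoint (e.g. under Population Based Training).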
def train_tuner(config, checkpoint_dir=None, cfg=None):
    update_config(cfg, config)

    tuner = AutoTuner(cfg)
    # Load checkpoint if one is specified
    if checkpoint_dir:
        path = os.path.join(checkpoint_dir, "checkpoint.pth")
        checkpoint = tuner.checkpointer.resume_or_load(path, resume=False)
        tuner.start_epoch = checkpoint.get("epoch", -1) + 1

    # Regular model training
    tuner.train()


def main(args):
    cfg = setup(args)

    exp_metrics = dict(metric="score", mode="max")

    if args.srch_algo == "hyperopt":
        # Create a HyperOpt search space
        search_space = {
            # "lr": hp.loguniform("lr", np.log(1e-6), np.log(1e-3)),
            # "delay_epochs": hp.randint("delay_epochs", 20, 60),
            # "wd": hp.uniform("wd", 0, 1e-3),
            # "wd_bias": hp.uniform("wd_bias", 0, 1e-3),
            "bsz": hp.choice("bsz", [64, 96, 128, 160, 224, 256]),
            "num_inst": hp.choice("num_inst", [2, 4, 8, 16, 32]),
            # "ce_scale": hp.uniform("ce_scale", 0.1, 1.0),
            # "circle_scale": hp.choice("circle_scale", [16, 32, 64, 128, 256]),
            # "circle_margin": hp.uniform("circle_margin", 0, 1) * 0.4 + 0.1,
        }

        current_best_params = [{
            "bsz": 0,  # index of hp.choice list
            "num_inst": 3,
        }]

        search_algo = HyperOptSearch(
            search_space,
            points_to_evaluate=current_best_params,
            **exp_metrics)

        if args.pbt:
            scheduler = PopulationBasedTraining(
                time_attr="training_iteration",
                **exp_metrics,
                perturbation_interval=2,
                hyperparam_mutations={
                    "bsz": [64, 96, 128, 160, 224, 256],
                    "num_inst": [2, 4, 8, 16, 32],
                }
            )
        else:
            scheduler = ASHAScheduler(
                grace_period=2,
                reduction_factor=3,
                max_t=7,
                **exp_metrics)

    elif args.srch_algo == "bohb":
        search_space = CS.ConfigurationSpace()
        search_space.add_hyperparameters([
            # CS.UniformFloatHyperparameter(name="lr", lower=1e-6, upper=1e-3, log=True),
            # CS.UniformIntegerHyperparameter(name="delay_epochs", lower=20, upper=60),
            # CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0),
            # CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=128),
            # CS.UniformFloatHyperparameter(name="circle_margin", lower=0.1, upper=0.5),
            # CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3),
            # CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3),
            CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]),
            CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]),
            # CS.CategoricalHyperparameter(name="autoaug_enabled", choices=[True, False]),
            # CS.CategoricalHyperparameter(name="cj_enabled", choices=[True, False]),
        ])

        search_algo = TuneBOHB(
            search_space, max_concurrent=4, **exp_metrics)

        scheduler = HyperBandForBOHB(
            time_attr="training_iteration",
            reduction_factor=3,
            max_t=7,
            **exp_metrics,
        )

    else:
        raise ValueError("Search algorithm must be chosen from [hyperopt, bohb], but got {}".format(args.srch_algo))

    reporter = CLIReporter(
        parameter_columns=["bsz", "num_inst"],
        metric_columns=["r1", "map", "training_iteration"])

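    # Each trial is allotted 4 CPUs and 1 GPU; Ray runs as many trials
    # concurrently as the available resources allow.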
    analysis = tune.run(
        partial(
            train_tuner,
            cfg=cfg),
        resources_per_trial={"cpu": 4, "gpu": 1},
        search_alg=search_algo,
        num_samples=args.num_trials,
        scheduler=scheduler,
        progress_reporter=reporter,
        local_dir=cfg.OUTPUT_DIR,
        keep_checkpoints_num=10,
        name=args.srch_algo)

    best_trial = analysis.get_best_trial("score", "max", "last")
    logger.info("Best trial config: {}".format(best_trial.config))
    logger.info("Best trial final validation mAP: {}, Rank-1: {}".format(
        best_trial.last_result["map"], best_trial.last_result["r1"]))

    save_dict = dict(R1=best_trial.last_result["r1"].item(), mAP=best_trial.last_result["map"].item())
    save_dict.update(best_trial.config)
    path = os.path.join(cfg.OUTPUT_DIR, "best_config.yaml")
    with PathManager.open(path, "w") as f:
        f.write(CfgNode(save_dict).dump())
    logger.info("Best config saved to {}".format(os.path.abspath(path)))


if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--num-trials", type=int, default=8, help="number of tuning trials to run")
    parser.add_argument("--srch-algo", type=str, default="hyperopt",
                        help="search algorithm for the hyperparameter space (hyperopt or bohb)")
    parser.add_argument("--pbt", action="store_true", help="use population based training")
    args = parser.parse_args()
    print("Command Line Args:", args)
    main(args)
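# Example invocation (a sketch; the script and config file names are assumptions):
#   python tune_net.py --config-file configs/Market1501/bagtricks_R50.yml \
#       --srch-algo bohb --num-trials 16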