support hpo project

Summary: support BOHB search algorithm for reid.
pull/299/head
liaoxingyu 2020-10-01 18:10:02 +08:00
parent 4629172a90
commit 85cebe311f
6 changed files with 333 additions and 0 deletions

View File

@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identificati
## What's New
- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/HPOReID) based on fastreid. See `projects/HPOReID`.
- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/attribute_recognition) based on fastreid. See `projects/attribute_recognition`.
- [Sep 2020] Automatic Mixed Precision training is supported with pytorch1.6 built-in `torch.cuda.amp`. Set `cfg.SOLVER.AMP_ENABLED=True` to switch it on.
- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/DistillReID) is supported, thanks to [guan'an wang](https://github.com/wangguanan) for the contribution.

View File

@ -0,0 +1,21 @@
# Hyper-Parameter Optimization in FastReID
This project includes training reid models with hyper-parameter optimization.
Install the following dependencies before running the project:
```bash
pip install 'ray[tune]'
pip install hpbandster ConfigSpace
```
## Training
To train a model with `BOHB`, run
```bash
python3 projects/HPOReID/train_hpo.py --config-file projects/HPOReID/configs/baseline.yml
```
## Known issues
todo

View File

@ -0,0 +1,93 @@
MODEL:
META_ARCHITECTURE: "Baseline"
FREEZE_LAYERS: ["backbone"]
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: "34x"
LAST_STRIDE: 1
FEAT_DIM: 512
NORM: "BN"
WITH_NL: False
WITH_IBN: True
PRETRAIN: True
PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet34_ibn_a-94bc1577.pth"
HEADS:
NAME: "EmbeddingHead"
NORM: "BN"
NECK_FEAT: "after"
EMBEDDING_DIM: 0
POOL_LAYER: "gempool"
CLS_LAYER: "circleSoftmax"
SCALE: 64
MARGIN: 0.35
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss",)
CE:
EPSILON: 0.1
SCALE: 1.
TRI:
MARGIN: 0.0
HARD_MINING: True
NORM_FEAT: False
SCALE: 1.
CIRCLE:
MARGIN: 0.25
ALPHA: 96
SCALE: 1.0
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
DO_AUTOAUG: True
REA:
ENABLED: True
CJ:
ENABLED: True
DO_PAD: True
DATALOADER:
PK_SAMPLER: True
NAIVE_WAY: False
NUM_INSTANCE: 16
NUM_WORKERS: 8
SOLVER:
AMP_ENABLED: False
OPT: "Adam"
SCHED: "WarmupCosineAnnealingLR"
MAX_ITER: 60
BASE_LR: 0.00035
BIAS_LR_FACTOR: 1.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0
IMS_PER_BATCH: 64
DELAY_ITERS: 30
ETA_MIN_LR: 0.00000077
FREEZE_ITERS: 5
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 5
CHECKPOINT_PERIOD: 100
TEST:
EVAL_PERIOD: 10
IMS_PER_BATCH: 256
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
COMBINEALL: False
CUDNN_BENCHMARK: True
OUTPUT_DIR: "projects/HPOReID/logs/dukemtmc/r34-ibn_bohb"

View File

@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from .tune_hooks import TuneReportHook

View File

@ -0,0 +1,34 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
from ray import tune
from fastreid.engine.hooks import EvalHook, flatten_results_dict
class TuneReportHook(EvalHook):
    """Evaluation hook that forwards the evaluation metrics to Ray Tune.

    After each evaluation the hook reports three values through
    ``tune.report``: ``r1`` (the ``'Rank-1'`` entry), ``map`` (the ``'mAP'``
    entry) and ``score`` (the mean of the two).
    """

    def _do_eval(self):
        """Run the evaluation function, validate its output and report to Tune.

        Raises:
            ValueError: if a value in the (flattened) results cannot be
                converted to ``float``.
        """
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception as err:
                    # Chain the original error so the real cause stays visible.
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from err

        # Remove extra memory cache of main process due to evaluation
        torch.cuda.empty_cache()

        if not results:
            # Nothing was evaluated (None / empty dict): subscripting the
            # results below would crash, so skip the report for this round.
            import logging

            logging.getLogger(__name__).warning(
                "Evaluation returned no results; skipping tune.report for this round."
            )
            return

        rank1 = results['Rank-1']
        mean_ap = results['mAP']
        tune.report(r1=rank1, map=mean_ap, score=(rank1 + mean_ap) / 2)

View File

@ -0,0 +1,177 @@
#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import sys
from functools import partial
import ConfigSpace as CS
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.engine import hooks
from fastreid.modeling import build_model
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.utils.events import CommonMetricPrinter
from hporeid import *
# Module-level logger used by the trainer and by main() below.
logger = logging.getLogger("fastreid.project.tune")
# NOTE: executed at import time, before any argument parsing; the dashboard
# host is set to the loopback address.
ray.init(dashboard_host='127.0.0.1')
class HyperTuneTrainer(DefaultTrainer):
    """Trainer variant whose evaluation hook reports metrics via ``TuneReportHook``."""

    def build_hooks(self):
        r"""
        Assemble the hooks used during a tuning trial, in order: iteration
        timing, lr scheduling, optional layer freezing, evaluation with
        Tune reporting, and periodic metric printing.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()

        def test_and_save_results():
            # Keep the latest evaluation results on the trainer as well.
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]

        # Layer freezing is only enabled when layers are configured and the
        # freeze period is positive.
        freeze_requested = cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0
        if freeze_requested:
            freeze_layers = ",".join(cfg.MODEL.FREEZE_LAYERS)
            logger.info(f'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
            freeze_hook = hooks.FreezeLayer(
                self.model,
                self.optimizer,
                cfg.MODEL.FREEZE_LAYERS,
                cfg.SOLVER.FREEZE_ITERS,
            )
            ret.append(freeze_hook)

        # Evaluate periodically and report the metrics to Tune.
        eval_hook = TuneReportHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)
        ret.append(eval_hook)

        # Writers run last so that evaluation metrics are already available.
        metric_printer = CommonMetricPrinter(self.max_iter)
        ret.append(hooks.PeriodicWriter([metric_printer], 200))
        return ret

    @classmethod
    def build_model(cls, cfg):
        """Build the model described by ``cfg`` and return it."""
        return build_model(cfg)
def setup(args):
    """
    Build the run configuration and perform basic setups.

    The defaults are overridden first by the file at ``args.config_file`` and
    then by the ``args.opts`` list; the result is frozen before
    ``default_setup`` is called. Returns the frozen config.
    """
    config = get_cfg()
    # Later merges take precedence: file first, command-line overrides second.
    for merge, source in (
        (config.merge_from_file, args.config_file),
        (config.merge_from_list, args.opts),
    ):
        merge(source)
    config.freeze()

    default_setup(config, args)
    return config
def train_reid_tune(cfg, config, checkpoint_dir=None):
    """Train one trial with the hyper-parameters sampled by the search algorithm.

    Args:
        cfg: base configuration; it is cloned, so the caller's object is
            left untouched.
        config: mapping of sampled hyper-parameters with the keys ``lr``,
            ``wd``, ``wd_bias``, ``bsz``, ``num_inst``, ``delay_iters``,
            ``ce_scale``, ``circle_scale``, ``circle_margin``,
            ``autoaug_enabled`` and ``cj_enabled``.
        checkpoint_dir: optional directory containing a file named
            ``checkpoint`` to restore the trial from.
    """
    # Work on a private copy so the shared base config is never mutated.
    cfg = cfg.clone()
    cfg.defrost()

    # Hyperparameter tuning
    cfg.SOLVER.BASE_LR = config["lr"]
    cfg.SOLVER.WEIGHT_DECAY = config["wd"]
    cfg.SOLVER.WEIGHT_DECAY_BIAS = config["wd_bias"]
    cfg.SOLVER.IMS_PER_BATCH = config["bsz"]
    cfg.DATALOADER.NUM_INSTANCE = config["num_inst"]
    cfg.SOLVER.DELAY_ITERS = config["delay_iters"]
    # The minimum lr follows the sampled lr with a fixed ratio.
    # NOTE(review): 0.0022 presumably mirrors ETA_MIN_LR / BASE_LR of the
    # baseline config (0.00000077 / 0.00035) -- confirm.
    cfg.SOLVER.ETA_MIN_LR = config["lr"] * 0.0022
    cfg.MODEL.LOSSES.CE.SCALE = config["ce_scale"]
    cfg.MODEL.HEADS.SCALE = config["circle_scale"]
    cfg.MODEL.HEADS.MARGIN = config["circle_margin"]
    cfg.INPUT.DO_AUTOAUG = config["autoaug_enabled"]
    cfg.INPUT.CJ.ENABLED = config["cj_enabled"]

    trainer = HyperTuneTrainer(cfg)

    if checkpoint_dir:
        path = os.path.join(checkpoint_dir, "checkpoint")
        checkpoint = trainer.checkpointer.resume_or_load(path)
        # `checkpoint` is the loaded state (read with .get below); the
        # checkpointer itself lives on the trainer, so query it there.
        if trainer.checkpointer.has_checkpoint():
            trainer.start_iter = checkpoint.get("iteration", -1) + 1

    # Regular model training
    trainer.train()
def main(args):
    """Run the BOHB hyper-parameter search and log the best trial.

    The search maximises the reported ``score`` metric; the best trial that is
    logged at the end is selected by the last reported ``map`` value.
    """
    cfg = setup(args)

    # Search space explored by the search algorithm.
    search_space = CS.ConfigurationSpace()
    hyperparameters = [
        CS.UniformFloatHyperparameter(name="lr", lower=1e-6, upper=1e-3),
        CS.UniformFloatHyperparameter(name="wd", lower=0, upper=1e-3),
        CS.UniformFloatHyperparameter(name="wd_bias", lower=0, upper=1e-3),
        CS.CategoricalHyperparameter(name="bsz", choices=[64, 96, 128, 160, 224, 256]),
        CS.CategoricalHyperparameter(name="num_inst", choices=[2, 4, 8, 16, 32]),
        CS.UniformIntegerHyperparameter(name="delay_iters", lower=20, upper=60),
        CS.UniformFloatHyperparameter(name="ce_scale", lower=0.1, upper=1.0),
        CS.UniformIntegerHyperparameter(name="circle_scale", lower=8, upper=256),
        CS.UniformFloatHyperparameter(name="circle_margin", lower=0.1, upper=0.5),
        CS.CategoricalHyperparameter(name="autoaug_enabled", choices=[True, False]),
        CS.CategoricalHyperparameter(name="cj_enabled", choices=[True, False]),
    ]
    search_space.add_hyperparameters(hyperparameters)

    # Scheduler and search algorithm share the same optimisation target.
    exp_metrics = {"metric": "score", "mode": "max"}
    bohb_hyperband = HyperBandForBOHB(time_attr="training_iteration", max_t=7, **exp_metrics)
    bohb_search = TuneBOHB(search_space, max_concurrent=4, **exp_metrics)

    reporter = CLIReporter(
        parameter_columns=["bsz", "num_inst", "lr"],
        metric_columns=["r1", "map", "training_iteration"],
    )

    # Bind the base config so that Tune only supplies the sampled values.
    trainable = partial(train_reid_tune, cfg)
    analysis = tune.run(
        trainable,
        resources_per_trial={"cpu": 10, "gpu": 1},
        search_alg=bohb_search,
        num_samples=args.num_samples,
        scheduler=bohb_hyperband,
        progress_reporter=reporter,
        local_dir=cfg.OUTPUT_DIR,
        keep_checkpoints_num=4,
        name="bohb",
    )

    best_trial = analysis.get_best_trial("map", "max", "last")
    logger.info("Best trial config: {}".format(best_trial.config))
    final_map = best_trial.last_result["map"]
    final_r1 = best_trial.last_result["r1"]
    logger.info("Best trial final validation mAP: {}, Rank-1: {}".format(final_map, final_r1))
if __name__ == "__main__":
    # Extend the default argument parser with the number of trials to sample.
    cli_parser = default_argument_parser()
    cli_parser.add_argument(
        "--num-samples",
        type=int,
        default=20,
        help="number of tune trials",
    )
    args = cli_parser.parse_args()
    print("Command Line Args:", args)
    main(args)