fast-reid/fastreid/solver/build.py

67 lines
1.9 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
2020-03-25 10:58:26 +08:00
from . import lr_scheduler
from . import optim
2020-02-10 07:38:56 +08:00
def build_optimizer(cfg, model):
    """Build the optimizer selected by ``cfg.SOLVER.OPT``.

    Every trainable parameter of ``model`` gets its own parameter group so
    that learning rate and weight decay can be tuned per parameter.

    Args:
        cfg: config object exposing a ``SOLVER`` node with ``BASE_LR``,
            ``WEIGHT_DECAY``, ``HEADS_LR_FACTOR``, ``BIAS_LR_FACTOR``,
            ``WEIGHT_DECAY_BIAS``, ``OPT`` and (for SGD) ``MOMENTUM``.
        model: object whose ``named_parameters()`` yields ``(name, param)``
            pairs, where each ``param`` has a ``requires_grad`` attribute
            (presumably a ``torch.nn.Module`` -- confirm against callers).

    Returns:
        The object produced by calling the attribute of the local ``optim``
        module named ``cfg.SOLVER.OPT`` on the collected parameter groups.
    """
    param_groups = []
    for param_name, param in model.named_parameters():
        # Frozen parameters are never handed to the optimizer.
        if not param.requires_grad:
            continue

        learning_rate = cfg.SOLVER.BASE_LR
        decay = cfg.SOLVER.WEIGHT_DECAY

        # Both checks are substring matches on the parameter name, so a
        # bias that lives under "heads" receives both LR factors.
        if "heads" in param_name:
            learning_rate = learning_rate * cfg.SOLVER.HEADS_LR_FACTOR
        if "bias" in param_name:
            learning_rate = learning_rate * cfg.SOLVER.BIAS_LR_FACTOR
            decay = cfg.SOLVER.WEIGHT_DECAY_BIAS

        param_groups.append(
            {
                "name": param_name,
                "params": [param],
                "lr": learning_rate,
                "weight_decay": decay,
            }
        )

    optimizer_name = cfg.SOLVER.OPT
    optimizer_factory = getattr(optim, optimizer_name)

    # Only SGD is given a momentum argument; every other optimizer is
    # constructed from the parameter groups alone.
    extra_kwargs = {}
    if optimizer_name == "SGD":
        extra_kwargs["momentum"] = cfg.SOLVER.MOMENTUM

    return optimizer_factory(param_groups, **extra_kwargs)
def build_lr_scheduler(cfg, optimizer):
    """Build the learning-rate scheduler(s) described by ``cfg.SOLVER``.

    Args:
        cfg: config object exposing a ``SOLVER`` node with ``WARMUP_ITERS``,
            ``WARMUP_FACTOR``, ``WARMUP_METHOD``, ``STEPS``, ``GAMMA``,
            ``MAX_EPOCH``, ``ETA_MIN_LR`` and ``SCHED``.
        optimizer: the optimizer passed to every scheduler as ``optimizer``.

    Returns:
        A dict that always contains ``"lr_sched"`` and additionally contains
        ``"warmup_sched"`` (inserted first) when
        ``cfg.SOLVER.WARMUP_ITERS > 0``.

    Raises:
        AttributeError: if the local ``lr_scheduler`` module has no attribute
            named ``cfg.SOLVER.SCHED``.
        KeyError: if ``cfg.SOLVER.SCHED`` is neither ``"MultiStepLR"`` nor
            ``"CosineAnnealingLR"``.
    """
    schedulers = {}

    # The warmup scheduler is optional and only built for a positive
    # warmup length.
    if cfg.SOLVER.WARMUP_ITERS > 0:
        warmup_kwargs = dict(
            optimizer=optimizer,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
        schedulers["warmup_sched"] = lr_scheduler.WarmupLR(**warmup_kwargs)

    # Keyword arguments for each supported scheduler. Both entries are
    # built eagerly, so every option below is read from cfg regardless of
    # which scheduler ends up being selected.
    multistep_kwargs = dict(
        optimizer=optimizer,
        milestones=cfg.SOLVER.STEPS,
        gamma=cfg.SOLVER.GAMMA,
    )
    cosine_kwargs = dict(
        optimizer=optimizer,
        T_max=cfg.SOLVER.MAX_EPOCH,
        eta_min=cfg.SOLVER.ETA_MIN_LR,
    )
    kwargs_by_scheduler = {
        "MultiStepLR": multistep_kwargs,
        "CosineAnnealingLR": cosine_kwargs,
    }

    # Resolve the class before the kwargs so an unknown name fails on the
    # module lookup first.
    scheduler_cls = getattr(lr_scheduler, cfg.SOLVER.SCHED)
    selected_kwargs = kwargs_by_scheduler[cfg.SOLVER.SCHED]
    schedulers["lr_sched"] = scheduler_cls(**selected_kwargs)

    return schedulers