# encoding: utf-8 """ @author: sherlock @contact: sherlockliao01@gmail.com """ import torch def make_optimizer(cfg, model): params = [] for key, value in model.named_parameters(): if not value.requires_grad: continue lr = cfg.SOLVER.BASE_LR weight_decay = cfg.SOLVER.WEIGHT_DECAY if "bias" in key: lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) else: optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) return optimizer def make_optimizer_with_center(cfg, model, center_criterion): params = [] for key, value in model.named_parameters(): if not value.requires_grad: continue lr = cfg.SOLVER.BASE_LR weight_decay = cfg.SOLVER.WEIGHT_DECAY if "bias" in key: lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) else: optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) return optimizer, optimizer_center