mirror of https://github.com/RE-OWOD/RE-OWOD
29 lines
903 B
Python
29 lines
903 B
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
import torch
|
|
|
|
from detectron2.config import CfgNode
|
|
from detectron2.solver import build_lr_scheduler as build_d2_lr_scheduler
|
|
|
|
from .lr_scheduler import WarmupPolyLR
|
|
|
|
|
|
def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """Construct the learning-rate scheduler named by the config.

    When ``cfg.SOLVER.LR_SCHEDULER_NAME`` is ``"WarmupPolyLR"``, this module's
    ``WarmupPolyLR`` is built from the ``cfg.SOLVER`` fields. Any other name is
    delegated unchanged to detectron2's ``build_lr_scheduler``.

    Args:
        cfg: config node; only fields under ``cfg.SOLVER`` are read here.
        optimizer: optimizer passed through to the scheduler constructor.

    Returns:
        The constructed scheduler instance.
    """
    solver = cfg.SOLVER

    # Only WarmupPolyLR is defined locally; every other scheduler name is
    # resolved by detectron2's own builder.
    if solver.LR_SCHEDULER_NAME != "WarmupPolyLR":
        return build_d2_lr_scheduler(cfg, optimizer)

    total_iters = solver.MAX_ITER
    return WarmupPolyLR(
        optimizer,
        total_iters,
        warmup_factor=solver.WARMUP_FACTOR,
        warmup_iters=solver.WARMUP_ITERS,
        warmup_method=solver.WARMUP_METHOD,
        power=solver.POLY_LR_POWER,
        constant_ending=solver.POLY_LR_CONSTANT_ENDING,
    )
|