mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
Support gradient clip
Follow detectron2's instruction and add gradient clip in step function of optimizer
This commit is contained in:
parent
2cabc3428a
commit
07b8251ccb
@ -42,9 +42,7 @@ SOLVER:
|
|||||||
OPT: Adam
|
OPT: Adam
|
||||||
MAX_EPOCH: 60
|
MAX_EPOCH: 60
|
||||||
BASE_LR: 0.00035
|
BASE_LR: 0.00035
|
||||||
BIAS_LR_FACTOR: 1.
|
|
||||||
WEIGHT_DECAY: 0.0005
|
WEIGHT_DECAY: 0.0005
|
||||||
WEIGHT_DECAY_BIAS: 0.0005
|
|
||||||
IMS_PER_BATCH: 64
|
IMS_PER_BATCH: 64
|
||||||
|
|
||||||
SCHED: CosineAnnealingLR
|
SCHED: CosineAnnealingLR
|
||||||
|
@ -56,9 +56,7 @@ SOLVER:
|
|||||||
OPT: Adam
|
OPT: Adam
|
||||||
MAX_EPOCH: 120
|
MAX_EPOCH: 120
|
||||||
BASE_LR: 0.00035
|
BASE_LR: 0.00035
|
||||||
BIAS_LR_FACTOR: 2.
|
|
||||||
WEIGHT_DECAY: 0.0005
|
WEIGHT_DECAY: 0.0005
|
||||||
WEIGHT_DECAY_BIAS: 0.0005
|
|
||||||
IMS_PER_BATCH: 64
|
IMS_PER_BATCH: 64
|
||||||
|
|
||||||
SCHED: MultiStepLR
|
SCHED: MultiStepLR
|
||||||
|
@ -4,36 +4,293 @@
|
|||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
# Based on: https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/build.py
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from fastreid.config import CfgNode
|
||||||
from . import lr_scheduler
|
from . import lr_scheduler
|
||||||
from . import optim
|
|
||||||
|
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
|
||||||
|
_GradientClipper = Callable[[_GradientClipperInput], None]
|
||||||
|
|
||||||
|
|
||||||
|
class GradientClipType(Enum):
|
||||||
|
VALUE = "value"
|
||||||
|
NORM = "norm"
|
||||||
|
|
||||||
|
|
||||||
|
def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
|
||||||
|
"""
|
||||||
|
Creates gradient clipping closure to clip by value or by norm,
|
||||||
|
according to the provided config.
|
||||||
|
"""
|
||||||
|
cfg = copy.deepcopy(cfg)
|
||||||
|
|
||||||
|
def clip_grad_norm(p: _GradientClipperInput):
|
||||||
|
torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
|
||||||
|
|
||||||
|
def clip_grad_value(p: _GradientClipperInput):
|
||||||
|
torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
|
||||||
|
|
||||||
|
_GRADIENT_CLIP_TYPE_TO_CLIPPER = {
|
||||||
|
GradientClipType.VALUE: clip_grad_value,
|
||||||
|
GradientClipType.NORM: clip_grad_norm,
|
||||||
|
}
|
||||||
|
return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_optimizer_class_with_gradient_clipping(
|
||||||
|
optimizer: Type[torch.optim.Optimizer],
|
||||||
|
*,
|
||||||
|
per_param_clipper: Optional[_GradientClipper] = None,
|
||||||
|
global_clipper: Optional[_GradientClipper] = None,
|
||||||
|
) -> Type[torch.optim.Optimizer]:
|
||||||
|
"""
|
||||||
|
Dynamically creates a new type that inherits the type of a given instance
|
||||||
|
and overrides the `step` method to add gradient clipping
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
per_param_clipper is None or global_clipper is None
|
||||||
|
), "Not allowed to use both per-parameter clipping and global clipping"
|
||||||
|
|
||||||
|
def optimizer_wgc_step(self, closure=None):
|
||||||
|
if per_param_clipper is not None:
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
per_param_clipper(p)
|
||||||
|
else:
|
||||||
|
# global clipper for future use with detr
|
||||||
|
# (https://github.com/facebookresearch/detr/pull/287)
|
||||||
|
all_params = itertools.chain(*[g["params"] for g in self.param_groups])
|
||||||
|
global_clipper(all_params)
|
||||||
|
optimizer.step(self, closure)
|
||||||
|
|
||||||
|
OptimizerWithGradientClip = type(
|
||||||
|
optimizer.__name__ + "WithGradientClip",
|
||||||
|
(optimizer,),
|
||||||
|
{"step": optimizer_wgc_step},
|
||||||
|
)
|
||||||
|
return OptimizerWithGradientClip
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_add_gradient_clipping(
|
||||||
|
cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
|
||||||
|
) -> Type[torch.optim.Optimizer]:
|
||||||
|
"""
|
||||||
|
If gradient clipping is enabled through config options, wraps the existing
|
||||||
|
optimizer type to become a new dynamically created class OptimizerWithGradientClip
|
||||||
|
that inherits the given optimizer and overrides the `step` method to
|
||||||
|
include gradient clipping.
|
||||||
|
Args:
|
||||||
|
cfg: CfgNode, configuration options
|
||||||
|
optimizer: type. A subclass of torch.optim.Optimizer
|
||||||
|
Return:
|
||||||
|
type: either the input `optimizer` (if gradient clipping is disabled), or
|
||||||
|
a subclass of it with gradient clipping included in the `step` method.
|
||||||
|
"""
|
||||||
|
if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
|
||||||
|
return optimizer
|
||||||
|
if isinstance(optimizer, torch.optim.Optimizer):
|
||||||
|
optimizer_type = type(optimizer)
|
||||||
|
else:
|
||||||
|
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
|
||||||
|
optimizer_type = optimizer
|
||||||
|
|
||||||
|
grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
|
||||||
|
OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
|
||||||
|
optimizer_type, per_param_clipper=grad_clipper
|
||||||
|
)
|
||||||
|
if isinstance(optimizer, torch.optim.Optimizer):
|
||||||
|
optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
|
||||||
|
return optimizer
|
||||||
|
else:
|
||||||
|
return OptimizerWithGradientClip
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_optimizer_class_with_freeze_layer(
|
||||||
|
optimizer: Type[torch.optim.Optimizer],
|
||||||
|
*,
|
||||||
|
freeze_layers: Optional[List] = None,
|
||||||
|
freeze_iters: int = 0,
|
||||||
|
) -> Type[torch.optim.Optimizer]:
|
||||||
|
assert (
|
||||||
|
freeze_layers is not None and freeze_iters > 0
|
||||||
|
), "No layers need to be frozen or freeze iterations is 0"
|
||||||
|
|
||||||
|
cnt = 0
|
||||||
|
|
||||||
|
def optimizer_wfl_step(self, closure=None):
|
||||||
|
nonlocal cnt
|
||||||
|
if cnt < freeze_iters:
|
||||||
|
cnt += 1
|
||||||
|
for group in self.param_groups:
|
||||||
|
if group["name"].split('.')[0] in freeze_layers:
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is not None:
|
||||||
|
p.grad = None
|
||||||
|
|
||||||
|
optimizer.step(self, closure)
|
||||||
|
|
||||||
|
OptimizerWithFreezeLayer = type(
|
||||||
|
optimizer.__name__ + "WithFreezeLayer",
|
||||||
|
(optimizer,),
|
||||||
|
{"step": optimizer_wfl_step},
|
||||||
|
)
|
||||||
|
return OptimizerWithFreezeLayer
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_add_freeze_layer(
|
||||||
|
cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
|
||||||
|
) -> Type[torch.optim.Optimizer]:
|
||||||
|
if cfg.MODEL.FREEZE_LAYERS == [''] or cfg.SOLVER.FREEZE_ITERS == 0:
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
if isinstance(optimizer, torch.optim.Optimizer):
|
||||||
|
optimizer_type = type(optimizer)
|
||||||
|
else:
|
||||||
|
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
|
||||||
|
optimizer_type = optimizer
|
||||||
|
|
||||||
|
OptimizerWithFreezeLayer = _generate_optimizer_class_with_freeze_layer(
|
||||||
|
optimizer_type,
|
||||||
|
freeze_layers=cfg.MODEL.FREEZE_LAYERS,
|
||||||
|
freeze_iters=cfg.SOLVER.FREEZE_ITERS
|
||||||
|
)
|
||||||
|
if isinstance(optimizer, torch.optim.Optimizer):
|
||||||
|
optimizer.__class__ = OptimizerWithFreezeLayer # a bit hacky, not recommended
|
||||||
|
return optimizer
|
||||||
|
else:
|
||||||
|
return OptimizerWithFreezeLayer
|
||||||
|
|
||||||
|
|
||||||
def build_optimizer(cfg, model):
|
def build_optimizer(cfg, model):
|
||||||
params = []
|
params = get_default_optimizer_params(
|
||||||
for key, value in model.named_parameters():
|
model,
|
||||||
if not value.requires_grad: continue
|
base_lr=cfg.SOLVER.BASE_LR,
|
||||||
|
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
|
||||||
lr = cfg.SOLVER.BASE_LR
|
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
|
||||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR,
|
||||||
if "heads" in key:
|
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS
|
||||||
lr *= cfg.SOLVER.HEADS_LR_FACTOR
|
)
|
||||||
if "bias" in key:
|
|
||||||
lr *= cfg.SOLVER.BIAS_LR_FACTOR
|
|
||||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
|
|
||||||
params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay}]
|
|
||||||
|
|
||||||
solver_opt = cfg.SOLVER.OPT
|
solver_opt = cfg.SOLVER.OPT
|
||||||
if solver_opt == "SGD":
|
if solver_opt == "SGD":
|
||||||
opt_fns = getattr(optim, solver_opt)(
|
return maybe_add_freeze_layer(
|
||||||
|
cfg,
|
||||||
|
maybe_add_gradient_clipping(cfg, torch.optim.SGD)
|
||||||
|
)(
|
||||||
params,
|
params,
|
||||||
|
lr=cfg.SOLVER.BASE_LR,
|
||||||
momentum=cfg.SOLVER.MOMENTUM,
|
momentum=cfg.SOLVER.MOMENTUM,
|
||||||
nesterov=True if cfg.SOLVER.MOMENTUM and cfg.SOLVER.NESTEROV else False
|
nesterov=cfg.SOLVER.NESTEROV,
|
||||||
|
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
opt_fns = getattr(optim, solver_opt)(params)
|
return maybe_add_freeze_layer(
|
||||||
return opt_fns
|
cfg,
|
||||||
|
maybe_add_gradient_clipping(cfg, getattr(torch.optim, solver_opt))
|
||||||
|
)(
|
||||||
|
params,
|
||||||
|
lr=cfg.SOLVER.BASE_LR,
|
||||||
|
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_optimizer_params(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
base_lr: Optional[float] = None,
|
||||||
|
weight_decay: Optional[float] = None,
|
||||||
|
weight_decay_norm: Optional[float] = None,
|
||||||
|
bias_lr_factor: Optional[float] = 1.0,
|
||||||
|
heads_lr_factor: Optional[float] = 1.0,
|
||||||
|
weight_decay_bias: Optional[float] = None,
|
||||||
|
overrides: Optional[Dict[str, Dict[str, float]]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get default param list for optimizer, with support for a few types of
|
||||||
|
overrides. If no overrides needed, this is equivalent to `model.parameters()`.
|
||||||
|
Args:
|
||||||
|
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
|
||||||
|
weight_decay: weight decay for every group by default. Can be omitted to use the one
|
||||||
|
in optimizer.
|
||||||
|
weight_decay_norm: override weight decay for params in normalization layers
|
||||||
|
bias_lr_factor: multiplier of lr for bias parameters.
|
||||||
|
heads_lr_factor: multiplier of lr for model.head parameters.
|
||||||
|
weight_decay_bias: override weight decay for bias parameters
|
||||||
|
overrides: if not `None`, provides values for optimizer hyperparameters
|
||||||
|
(LR, weight decay) for module parameters with a given name; e.g.
|
||||||
|
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
|
||||||
|
weight decay values for all module parameters named `embedding`.
|
||||||
|
For common detection models, ``weight_decay_norm`` is the only option
|
||||||
|
needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
|
||||||
|
from Detectron1 that are not found useful.
|
||||||
|
Example:
|
||||||
|
::
|
||||||
|
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
|
||||||
|
lr=0.01, weight_decay=1e-4, momentum=0.9)
|
||||||
|
"""
|
||||||
|
if overrides is None:
|
||||||
|
overrides = {}
|
||||||
|
defaults = {}
|
||||||
|
if base_lr is not None:
|
||||||
|
defaults["lr"] = base_lr
|
||||||
|
if weight_decay is not None:
|
||||||
|
defaults["weight_decay"] = weight_decay
|
||||||
|
bias_overrides = {}
|
||||||
|
if bias_lr_factor is not None and bias_lr_factor != 1.0:
|
||||||
|
# NOTE: unlike Detectron v1, we now by default make bias hyperparameters
|
||||||
|
# exactly the same as regular weights.
|
||||||
|
if base_lr is None:
|
||||||
|
raise ValueError("bias_lr_factor requires base_lr")
|
||||||
|
bias_overrides["lr"] = base_lr * bias_lr_factor
|
||||||
|
if weight_decay_bias is not None:
|
||||||
|
bias_overrides["weight_decay"] = weight_decay_bias
|
||||||
|
if len(bias_overrides):
|
||||||
|
if "bias" in overrides:
|
||||||
|
raise ValueError("Conflicting overrides for 'bias'")
|
||||||
|
overrides["bias"] = bias_overrides
|
||||||
|
|
||||||
|
norm_module_types = (
|
||||||
|
torch.nn.BatchNorm1d,
|
||||||
|
torch.nn.BatchNorm2d,
|
||||||
|
torch.nn.BatchNorm3d,
|
||||||
|
torch.nn.SyncBatchNorm,
|
||||||
|
# NaiveSyncBatchNorm inherits from BatchNorm2d
|
||||||
|
torch.nn.GroupNorm,
|
||||||
|
torch.nn.InstanceNorm1d,
|
||||||
|
torch.nn.InstanceNorm2d,
|
||||||
|
torch.nn.InstanceNorm3d,
|
||||||
|
torch.nn.LayerNorm,
|
||||||
|
torch.nn.LocalResponseNorm,
|
||||||
|
)
|
||||||
|
params: List[Dict[str, Any]] = []
|
||||||
|
memo: Set[torch.nn.parameter.Parameter] = set()
|
||||||
|
|
||||||
|
for module_name, module in model.named_modules():
|
||||||
|
for module_param_name, value in module.named_parameters(recurse=False):
|
||||||
|
if not value.requires_grad:
|
||||||
|
continue
|
||||||
|
# Avoid duplicating parameters
|
||||||
|
if value in memo:
|
||||||
|
continue
|
||||||
|
memo.add(value)
|
||||||
|
|
||||||
|
hyperparams = copy.copy(defaults)
|
||||||
|
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
|
||||||
|
hyperparams["weight_decay"] = weight_decay_norm
|
||||||
|
hyperparams.update(overrides.get(module_param_name, {}))
|
||||||
|
if module_name.split('.')[0] == "heads" and (heads_lr_factor is not None and heads_lr_factor != 1.0):
|
||||||
|
hyperparams["lr"] = hyperparams.get("lr", base_lr) * heads_lr_factor
|
||||||
|
params.append({"name": module_name + '.' + module_param_name,
|
||||||
|
"params": [value], **hyperparams})
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
|
def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user