From 07b8251ccb9f296e4be5d49fc98efebaa12d679f Mon Sep 17 00:00:00 2001
From: liaoxingyu
Date: Mon, 31 May 2021 17:11:37 +0800
Subject: [PATCH] Support gradient clipping

Follow detectron2's implementation and add gradient clipping to the
optimizer's step function.
---
 configs/Base-SBS.yml       |   2 -
 configs/Base-bagtricks.yml |   2 -
 fastreid/solver/build.py   | 293 ++++++++++++++++++++++++++++++++++---
 3 files changed, 275 insertions(+), 22 deletions(-)

diff --git a/configs/Base-SBS.yml b/configs/Base-SBS.yml
index d61e7cc..33fe81a 100644
--- a/configs/Base-SBS.yml
+++ b/configs/Base-SBS.yml
@@ -42,9 +42,7 @@ SOLVER:
   OPT: Adam
   MAX_EPOCH: 60
   BASE_LR: 0.00035
-  BIAS_LR_FACTOR: 1.
   WEIGHT_DECAY: 0.0005
-  WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64
 
   SCHED: CosineAnnealingLR
diff --git a/configs/Base-bagtricks.yml b/configs/Base-bagtricks.yml
index acab5f1..f2cbe5d 100644
--- a/configs/Base-bagtricks.yml
+++ b/configs/Base-bagtricks.yml
@@ -56,9 +56,7 @@ SOLVER:
   OPT: Adam
   MAX_EPOCH: 120
   BASE_LR: 0.00035
-  BIAS_LR_FACTOR: 2.
   WEIGHT_DECAY: 0.0005
-  WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64
 
   SCHED: MultiStepLR
diff --git a/fastreid/solver/build.py b/fastreid/solver/build.py
index 806aa7e..e6245e8 100644
--- a/fastreid/solver/build.py
+++ b/fastreid/solver/build.py
@@ -4,36 +4,293 @@
 @contact: sherlockliao01@gmail.com
 """
 
-import math
+# Based on: https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/build.py
+import copy
+import itertools
+import math
+from enum import Enum
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
+
+import torch
+
+from fastreid.config import CfgNode
 
 from . import lr_scheduler
-from . import optim
+
+_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
+_GradientClipper = Callable[[_GradientClipperInput], None]
+
+
+class GradientClipType(Enum):
+    VALUE = "value"
+    NORM = "norm"
+
+
+def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
+    """
+    Creates gradient clipping closure to clip by value or by norm,
+    according to the provided config.
+    """
+    cfg = copy.deepcopy(cfg)
+
+    def clip_grad_norm(p: _GradientClipperInput):
+        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
+
+    def clip_grad_value(p: _GradientClipperInput):
+        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
+
+    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
+        GradientClipType.VALUE: clip_grad_value,
+        GradientClipType.NORM: clip_grad_norm,
+    }
+    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
+
+
+def _generate_optimizer_class_with_gradient_clipping(
+    optimizer: Type[torch.optim.Optimizer],
+    *,
+    per_param_clipper: Optional[_GradientClipper] = None,
+    global_clipper: Optional[_GradientClipper] = None,
+) -> Type[torch.optim.Optimizer]:
+    """
+    Dynamically creates a new type that inherits the type of a given instance
+    and overrides the `step` method to add gradient clipping
+    """
+    assert (
+        per_param_clipper is None or global_clipper is None
+    ), "Not allowed to use both per-parameter clipping and global clipping"
+
+    def optimizer_wgc_step(self, closure=None):
+        if per_param_clipper is not None:
+            for group in self.param_groups:
+                for p in group["params"]:
+                    per_param_clipper(p)
+        else:
+            # global clipper for future use with detr
+            # (https://github.com/facebookresearch/detr/pull/287)
+            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
+            global_clipper(all_params)
+        optimizer.step(self, closure)
+
+    OptimizerWithGradientClip = type(
+        optimizer.__name__ + "WithGradientClip",
+        (optimizer,),
+        {"step": optimizer_wgc_step},
+    )
+    return OptimizerWithGradientClip
+
+
+def maybe_add_gradient_clipping(
+    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
+) -> Type[torch.optim.Optimizer]:
+    """
+    If gradient clipping is enabled through config options, wraps the existing
+    optimizer type to become a new dynamically created class OptimizerWithGradientClip
+    that inherits the given optimizer and overrides the `step` method to
+    include gradient clipping.
+    Args:
+        cfg: CfgNode, configuration options
+        optimizer: type. A subclass of torch.optim.Optimizer
+    Return:
+        type: either the input `optimizer` (if gradient clipping is disabled), or
+        a subclass of it with gradient clipping included in the `step` method.
+    """
+    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
+        return optimizer
+    if isinstance(optimizer, torch.optim.Optimizer):
+        optimizer_type = type(optimizer)
+    else:
+        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
+        optimizer_type = optimizer
+
+    grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
+    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
+        optimizer_type, per_param_clipper=grad_clipper
+    )
+    if isinstance(optimizer, torch.optim.Optimizer):
+        optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
+        return optimizer
+    else:
+        return OptimizerWithGradientClip
+
+
+def _generate_optimizer_class_with_freeze_layer(
+    optimizer: Type[torch.optim.Optimizer],
+    *,
+    freeze_layers: Optional[List] = None,
+    freeze_iters: int = 0,
+) -> Type[torch.optim.Optimizer]:
+    assert (
+        freeze_layers is not None and freeze_iters > 0
+    ), "No layers need to be frozen or freeze iterations is 0"
+
+    cnt = 0
+
+    def optimizer_wfl_step(self, closure=None):
+        nonlocal cnt
+        if cnt < freeze_iters:
+            cnt += 1
+            for group in self.param_groups:
+                if group["name"].split('.')[0] in freeze_layers:
+                    for p in group["params"]:
+                        if p.grad is not None:
+                            p.grad = None
+
+        optimizer.step(self, closure)
+
+    OptimizerWithFreezeLayer = type(
+        optimizer.__name__ + "WithFreezeLayer",
+        (optimizer,),
+        {"step": optimizer_wfl_step},
+    )
+    return OptimizerWithFreezeLayer
+
+
+def maybe_add_freeze_layer(
+    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
+) -> Type[torch.optim.Optimizer]:
+    if cfg.MODEL.FREEZE_LAYERS == [''] or cfg.SOLVER.FREEZE_ITERS == 0:
+        return optimizer
+
+    if isinstance(optimizer, torch.optim.Optimizer):
+        optimizer_type = type(optimizer)
+    else:
+        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
+        optimizer_type = optimizer
+
+    OptimizerWithFreezeLayer = _generate_optimizer_class_with_freeze_layer(
+        optimizer_type,
+        freeze_layers=cfg.MODEL.FREEZE_LAYERS,
+        freeze_iters=cfg.SOLVER.FREEZE_ITERS
+    )
+    if isinstance(optimizer, torch.optim.Optimizer):
+        optimizer.__class__ = OptimizerWithFreezeLayer  # a bit hacky, not recommended
+        return optimizer
+    else:
+        return OptimizerWithFreezeLayer
 
 
 def build_optimizer(cfg, model):
-    params = []
-    for key, value in model.named_parameters():
-        if not value.requires_grad: continue
-
-        lr = cfg.SOLVER.BASE_LR
-        weight_decay = cfg.SOLVER.WEIGHT_DECAY
-        if "heads" in key:
-            lr *= cfg.SOLVER.HEADS_LR_FACTOR
-        if "bias" in key:
-            lr *= cfg.SOLVER.BIAS_LR_FACTOR
-            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
-        params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay}]
+    params = get_default_optimizer_params(
+        model,
+        base_lr=cfg.SOLVER.BASE_LR,
+        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+        heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR,
+        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS
+    )
 
     solver_opt = cfg.SOLVER.OPT
     if solver_opt == "SGD":
-        opt_fns = getattr(optim, solver_opt)(
+        return maybe_add_freeze_layer(
+            cfg,
+            maybe_add_gradient_clipping(cfg, torch.optim.SGD)
+        )(
             params,
+            lr=cfg.SOLVER.BASE_LR,
             momentum=cfg.SOLVER.MOMENTUM,
-            nesterov=True if cfg.SOLVER.MOMENTUM and cfg.SOLVER.NESTEROV else False
+            nesterov=cfg.SOLVER.NESTEROV,
+            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
         )
     else:
-        opt_fns = getattr(optim, solver_opt)(params)
-    return opt_fns
+        return maybe_add_freeze_layer(
+            cfg,
+            maybe_add_gradient_clipping(cfg, getattr(torch.optim, solver_opt))
+        )(
+            params,
+            lr=cfg.SOLVER.BASE_LR,
+            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+        )
+
+
+def get_default_optimizer_params(
+    model: torch.nn.Module,
+    base_lr: Optional[float] = None,
+    weight_decay: Optional[float] = None,
+    weight_decay_norm: Optional[float] = None,
+    bias_lr_factor: Optional[float] = 1.0,
+    heads_lr_factor: Optional[float] = 1.0,
+    weight_decay_bias: Optional[float] = None,
+    overrides: Optional[Dict[str, Dict[str, float]]] = None,
+):
+    """
+    Get default param list for optimizer, with support for a few types of
+    overrides. If no overrides needed, this is equivalent to `model.parameters()`.
+    Args:
+        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
+        weight_decay: weight decay for every group by default. Can be omitted to use the one
+            in optimizer.
+        weight_decay_norm: override weight decay for params in normalization layers
+        bias_lr_factor: multiplier of lr for bias parameters.
+        heads_lr_factor: multiplier of lr for model.head parameters.
+        weight_decay_bias: override weight decay for bias parameters
+        overrides: if not `None`, provides values for optimizer hyperparameters
+            (LR, weight decay) for module parameters with a given name; e.g.
+            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
+            weight decay values for all module parameters named `embedding`.
+    For common detection models, ``weight_decay_norm`` is the only option
+    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
+    from Detectron1 that are not found useful.
+    Example:
+    ::
+        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
+                       lr=0.01, weight_decay=1e-4, momentum=0.9)
+    """
+    if overrides is None:
+        overrides = {}
+    defaults = {}
+    if base_lr is not None:
+        defaults["lr"] = base_lr
+    if weight_decay is not None:
+        defaults["weight_decay"] = weight_decay
+    bias_overrides = {}
+    if bias_lr_factor is not None and bias_lr_factor != 1.0:
+        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
+        # exactly the same as regular weights.
+        if base_lr is None:
+            raise ValueError("bias_lr_factor requires base_lr")
+        bias_overrides["lr"] = base_lr * bias_lr_factor
+    if weight_decay_bias is not None:
+        bias_overrides["weight_decay"] = weight_decay_bias
+    if len(bias_overrides):
+        if "bias" in overrides:
+            raise ValueError("Conflicting overrides for 'bias'")
+        overrides["bias"] = bias_overrides
+
+    norm_module_types = (
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.SyncBatchNorm,
+        # NaiveSyncBatchNorm inherits from BatchNorm2d
+        torch.nn.GroupNorm,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.LocalResponseNorm,
+    )
+    params: List[Dict[str, Any]] = []
+    memo: Set[torch.nn.parameter.Parameter] = set()
+
+    for module_name, module in model.named_modules():
+        for module_param_name, value in module.named_parameters(recurse=False):
+            if not value.requires_grad:
+                continue
+            # Avoid duplicating parameters
+            if value in memo:
+                continue
+            memo.add(value)
+
+            hyperparams = copy.copy(defaults)
+            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
+                hyperparams["weight_decay"] = weight_decay_norm
+            hyperparams.update(overrides.get(module_param_name, {}))
+            if module_name.split('.')[0] == "heads" and (heads_lr_factor is not None and heads_lr_factor != 1.0):
+                hyperparams["lr"] = hyperparams.get("lr", base_lr) * heads_lr_factor
+            params.append({"name": module_name + '.' + module_param_name,
+                           "params": [value], **hyperparams})
+    return params
 
 
 def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
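
Usage sketch (not part of the patch): the snippet below shows how the new maybe_add_gradient_clipping wrapper composes with a plain torch optimizer, assuming the patched module is importable as fastreid.solver.build. The CLIP_VALUE and NORM_TYPE numbers are illustrative only, and the minimal CfgNode built here stands in for the project's full default config; in build_optimizer above, the clipping wrapper is additionally nested inside maybe_add_freeze_layer, which relies on the "name" keys produced by get_default_optimizer_params.

    # Minimal sketch, not the patch itself; config values below are assumptions.
    import torch
    from fastreid.config import CfgNode as CN
    from fastreid.solver.build import maybe_add_gradient_clipping

    cfg = CN()
    cfg.SOLVER = CN()
    cfg.SOLVER.CLIP_GRADIENTS = CN()
    cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"  # "value" selects clip_grad_value_ instead
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 5.0    # illustrative threshold
    cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0     # illustrative norm order

    model = torch.nn.Linear(16, 4)

    # maybe_add_gradient_clipping returns a dynamically created SGD subclass whose
    # step() clips each parameter's gradient before delegating to SGD.step().
    SGDWithClip = maybe_add_gradient_clipping(cfg, torch.optim.SGD)
    optimizer = SGDWithClip(model.parameters(), lr=0.01, momentum=0.9)

    loss = model(torch.randn(2, 16)).sum()
    loss.backward()
    optimizer.step()  # per-parameter gradients are clipped before the update

Because the wrapper subclasses the optimizer type rather than patching instances, any torch.optim optimizer selected via cfg.SOLVER.OPT picks up clipping the same way.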