[Fix] Fix init_cfg

parent 7b4573cea0
commit 63343dc116
```diff
@@ -31,7 +31,6 @@ model = dict(
-        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)],
     ))


 # dataset setting
 data_preprocessor = dict(
     mean=[127.5, 127.5, 127.5],
```
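The line removed above applied a blanket `TruncNormal` init to every `nn.Linear` in the model, which would also sweep up the LoRA projections whenever the wrapped module's `init_weights` runs, clobbering their dedicated Kaiming/zeros initialization. After this fix, LoRA initialization is owned by the LoRA layers themselves and can be customized through the wrapper's own `init_cfg`. A minimal config sketch of the new style; the backbone and target keys are illustrative, only the `init_cfg` routing is taken from this commit:

```python
import math

# Hypothetical config fragment, assuming LoRAModel is registered in
# MODELS under this name (as the module below suggests).
model = dict(
    type='LoRAModel',
    module=dict(type='VisionTransformer', arch='base'),  # assumed backbone
    alpha=8,
    rank=8,
    targets=[dict(type='.*qkv')],
    init_cfg=[
        dict(type='Kaiming', name='lora_down', a=math.sqrt(5)),
        dict(type='Constant', name='lora_up', val=0),
    ])
```

The remaining hunks change the LoRA module itself.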
```diff
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 import re
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import torch
 from mmengine.logging import print_log
 from mmengine.model import BaseModule
+from mmengine.model.weight_init import _initialize
 from torch import nn

 from mmpretrain.registry import MODELS
```
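`_initialize` is the low-level mmengine helper behind `initialize()`: it builds the weight initializer described by a config dict and applies it to a module, and `wholemodule=True` makes it act on the module it is given rather than matching children by `layer` type. A small sketch of the call the new code relies on; note `_initialize` is a private helper, so its signature may shift between mmengine versions:

```python
import math

from mmengine.model.weight_init import _initialize
from torch import nn

# Apply a Kaiming init to this specific module, bypassing the usual
# `layer`-type matching of the initializer.
layer = nn.Linear(768, 4)
_initialize(layer, dict(type='Kaiming', a=math.sqrt(5)), wholemodule=True)
```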
```diff
@@ -19,6 +20,8 @@ class LoRALinear(nn.Module):
         alpha (int): The scale factor of LoRA. Defaults to 1.
         rank (int): The rank of LoRA. Defaults to 0.
         drop_rate (float): The drop out rate for LoRA. Defaults to 0.
+        init_cfg (Union[List[dict], dict], optional):
+            Initialization config dict. Defaults to None.

     Note:
         The forward process of LoRA linear layer is:
```
```diff
@@ -36,7 +39,8 @@ class LoRALinear(nn.Module):
                  original_layer: nn.Linear,
                  alpha: int = 1,
                  rank: int = 0,
-                 drop_rate: float = 0.):
+                 drop_rate: float = 0.,
+                 init_cfg: Optional[Union[List[dict], dict]] = None):
         super(LoRALinear, self).__init__()
         in_features = original_layer.in_features
         out_features = original_layer.out_features
```
```diff
@@ -46,10 +50,39 @@ class LoRALinear(nn.Module):
         self.lora_up = nn.Linear(rank, out_features, bias=False)
         self.scaling = alpha / rank

-        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_down.weight)
-
         self.original_layer = original_layer
+        self.init_cfg = init_cfg
+
+        self.init_weights()
+
+    def init_weights(self):
+        # Defaults: Kaiming for the down projection and zeros for the
+        # up projection, so the LoRA branch starts as a no-op.
+        lora_down_init_cfg = dict(type='Kaiming', a=math.sqrt(5))
+        lora_up_init_cfg = dict(type='Constant', val=0)
+
+        if self.init_cfg is None:
+            init_cfg = []
+        elif isinstance(self.init_cfg, dict):
+            init_cfg = [self.init_cfg]
+        else:
+            init_cfg = self.init_cfg
+
+        is_lora_down_inited, is_lora_up_inited = False, False
+        for cfg in init_cfg:
+            cfg = dict(cfg)  # copy before popping, the cfg may be shared
+            name = cfg.pop('name', None)
+            if name == 'lora_down':
+                _initialize(self.lora_down, cfg, wholemodule=True)
+                is_lora_down_inited = True
+            elif name == 'lora_up':
+                _initialize(self.lora_up, cfg, wholemodule=True)
+                is_lora_up_inited = True
+
+        if not is_lora_down_inited:
+            _initialize(self.lora_down, lora_down_init_cfg, wholemodule=True)
+        if not is_lora_up_inited:
+            _initialize(self.lora_up, lora_up_init_cfg, wholemodule=True)

     def forward(self, x: torch.Tensor):
         out = self.original_layer(x)
```
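The hunk's trailing context cuts `forward` off after the original layer's output. Following the LoRA formulation referenced in the class docstring, the method presumably continues by adding the scaled low-rank path; a sketch for orientation, not the verbatim file contents, with the dropout attribute name assumed since it is not visible here:

```python
def forward(self, x: torch.Tensor):
    # Frozen pre-trained path.
    out = self.original_layer(x)
    # Low-rank path: y = W0(x) + (alpha / rank) * up(down(dropout(x))).
    # `lora_dropout` is an assumed attribute name, not shown in the hunk.
    lora_out = self.lora_up(self.lora_down(self.lora_dropout(x)))
    return out + lora_out * self.scaling
```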
```diff
@@ -75,8 +103,8 @@ class LoRAModel(BaseModule):
         drop_rate (float): The drop out rate for LoRA. Defaults to 0.
         targets (List[dict]): The target layers to be applied with the LoRA.
             Defaults to a empty list.
-        init_cfg (dict, optional): Initialization config dict.
-            Defaults to None.
+        init_cfg (Union[List[dict], dict], optional):
+            Initialization config dict. Defaults to None.

     Examples:
         >>> model = LoRAModel(
```
```diff
@@ -87,6 +115,10 @@ class LoRAModel(BaseModule):
         ...     targets=[
         ...         dict(type='qkv'),
         ...         dict(type='.*proj', alpha=8, rank=8, drop_rate=0.2),
-        ...     ])
+        ...     ],
+        ...     init_cfg=[
+        ...         dict(type='Kaiming', name='lora_down', a=math.sqrt(5)),
+        ...         dict(type='Constant', name='lora_up', val=0)
+        ...     ])
     """

```
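Each `init_cfg` entry is routed by its `name` key to the matching LoRA projection, and any projection left unmatched falls back to the defaults in `init_weights`. A quick usage check of the default behavior:

```python
import torch
from torch import nn

# Assumed import path; the file names are not shown in this diff view.
from mmpretrain.models.peft.lora import LoRALinear

# No init_cfg: the defaults apply, i.e. Kaiming for lora_down and zeros
# for lora_up, so the LoRA branch contributes nothing at the start.
layer = LoRALinear(nn.Linear(768, 768), alpha=8, rank=8)
assert torch.all(layer.lora_up.weight == 0)
```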
```diff
@@ -96,9 +128,10 @@ class LoRAModel(BaseModule):
                  rank: int = 0,
                  drop_rate: float = 0.,
                  targets: List[dict] = list(),
-                 init_cfg: Optional[dict] = None):
+                 init_cfg: Optional[Union[List[dict], dict]] = None):

         super().__init__(init_cfg)
+        self.init_cfg = init_cfg

         module = MODELS.build(module)
         module.init_weights()
```
```diff
@@ -156,7 +189,8 @@ class LoRAModel(BaseModule):
         parent_module = self.module.get_submodule(parent_module_name)

         target_name = module_name.split('.')[-1]
-        target_module = LoRALinear(current_module, alpha, rank, drop_rate)
+        target_module = LoRALinear(current_module, alpha, rank, drop_rate,
+                                   self.init_cfg)
         setattr(parent_module, target_name, target_module)

     def _set_lora_trainable(self):
```
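The replacement pattern in this last hunk, resolving the parent with `get_submodule` and rebinding the attribute with `setattr`, is the standard way to swap a layer in place without touching the rest of the module tree. A standalone sketch with illustrative names, assuming `LoRALinear` from this module is in scope:

```python
from torch import nn

model = nn.Sequential(nn.Linear(8, 8))
module_name = '0'  # dotted path of the layer to replace

# Split off the parent path; get_submodule('') returns the model itself.
parent_name, _, target_name = module_name.rpartition('.')
parent_module = model.get_submodule(parent_name)
current_module = getattr(parent_module, target_name)
setattr(parent_module, target_name,
        LoRALinear(current_module, alpha=1, rank=4))
```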