[Fix] Fix init_cfg

Branch: pull/1687/head
Author: fanqiNO1
Date: 2023-07-20 16:27:54 +08:00
Parent: 7b4573cea0
Commit: 63343dc116
2 changed files with 43 additions and 10 deletions


@@ -31,7 +31,6 @@ model = dict(
        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)],
))

# dataset setting
data_preprocessor = dict(
    mean=[127.5, 127.5, 127.5],


@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 import re
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 import torch
 from mmengine.logging import print_log
 from mmengine.model import BaseModule
+from mmengine.model.weight_init import _initialize
 from torch import nn

 from mmpretrain.registry import MODELS
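
The newly imported `_initialize` is a private helper from `mmengine.model.weight_init`. Judging from how this commit calls it, it builds a weight initializer from an init config dict and applies it to the given module, with `wholemodule=True` skipping the usual per-layer filtering. A minimal sketch of that behaviour (the semantics are inferred from this usage, not from a documented public API):

```python
import torch
from torch import nn
from mmengine.model.weight_init import _initialize

# Zero-initialize a plain Linear the same way the patched LoRALinear
# zero-initializes its lora_up branch by default.
linear = nn.Linear(8, 8)
_initialize(linear, dict(type='Constant', val=0), wholemodule=True)
assert torch.all(linear.weight == 0)
```
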
@@ -19,6 +20,8 @@ class LoRALinear(nn.Module):
         alpha (int): The scale factor of LoRA. Defaults to 1.
         rank (int): The rank of LoRA. Defaults to 0.
         drop_rate (float): The drop out rate for LoRA. Defaults to 0.
+        init_cfg (Union[List[dict], dict], optional):
+            Initialization config dict. Defaults to None.

     Note:
         The forward process of LoRA linear layer is:
@@ -36,7 +39,8 @@ class LoRALinear(nn.Module):
                  original_layer: nn.Linear,
                  alpha: int = 1,
                  rank: int = 0,
-                 drop_rate: float = 0.):
+                 drop_rate: float = 0.,
+                 init_cfg: Optional[Union[List[dict], dict]] = None):
         super(LoRALinear, self).__init__()
         in_features = original_layer.in_features
         out_features = original_layer.out_features
@@ -46,10 +50,34 @@ class LoRALinear(nn.Module):
         self.lora_up = nn.Linear(rank, out_features, bias=False)
         self.scaling = alpha / rank

-        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        nn.init.zeros_(self.lora_down.weight)

         self.original_layer = original_layer
+        self.init_cfg = init_cfg
+        self.init_weights()
+
+    def init_weights(self):
+        # Default LoRA scheme: Kaiming for lora_down, zeros for lora_up,
+        # so the LoRA branch initially contributes nothing.
+        lora_down_init_cfg = dict(type='Kaiming', a=math.sqrt(5))
+        lora_up_init_cfg = dict(type='Constant', val=0)
+
+        if self.init_cfg is None:
+            init_cfg = []
+        elif isinstance(self.init_cfg, dict):
+            init_cfg = [self.init_cfg]
+        else:
+            init_cfg = self.init_cfg
+
+        is_lora_down_inited, is_lora_up_inited = False, False
+        for cfg in init_cfg:
+            # Copy so popping 'name' does not mutate a config shared
+            # across several LoRALinear layers.
+            cfg = dict(cfg)
+            name = cfg.pop('name', None)
+            if name == 'lora_down':
+                _initialize(self.lora_down, cfg, wholemodule=True)
+                is_lora_down_inited = True
+            elif name == 'lora_up':
+                _initialize(self.lora_up, cfg, wholemodule=True)
+                is_lora_up_inited = True
+
+        if not is_lora_down_inited:
+            _initialize(self.lora_down, lora_down_init_cfg, wholemodule=True)
+        if not is_lora_up_inited:
+            _initialize(self.lora_up, lora_up_init_cfg, wholemodule=True)

     def forward(self, x: torch.Tensor):
         out = self.original_layer(x)
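
A quick sanity check of the new behaviour: wrap a plain `nn.Linear` in `LoRALinear` and override only the `lora_up` init, letting `lora_down` fall back to the default Kaiming scheme. This is a hedged sketch; the import path and the tensor shapes are assumptions, not part of this diff.

```python
import torch
from torch import nn
from mmpretrain.models.peft.lora import LoRALinear  # module path assumed

layer = nn.Linear(768, 768)
lora_layer = LoRALinear(
    layer,
    alpha=16,
    rank=8,
    drop_rate=0.1,
    # Only lora_up is configured; lora_down keeps the default Kaiming init
    # defined in init_weights above.
    init_cfg=[dict(type='TruncNormal', name='lora_up', std=1e-4)])

x = torch.randn(2, 197, 768)
out = lora_layer(x)  # original output plus the scaled low-rank update
print(out.shape)     # torch.Size([2, 197, 768])
```
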
@@ -75,8 +103,8 @@ class LoRAModel(BaseModule):
         drop_rate (float): The drop out rate for LoRA. Defaults to 0.
         targets (List[dict]): The target layers to be applied with the LoRA.
             Defaults to an empty list.
-        init_cfg (dict, optional): Initialization config dict.
-            Defaults to None.
+        init_cfg (Union[List[dict], dict], optional):
+            Initialization config dict. Defaults to None.

     Examples:
         >>> model = LoRAModel(
@@ -87,6 +115,10 @@ class LoRAModel(BaseModule):
         ...     targets=[
         ...         dict(type='qkv'),
         ...         dict(type='.*proj', alpha=8, rank=8, drop_rate=0.2),
-        ...     ])
+        ...     ],
+        ...     init_cfg=[
+        ...         dict(type='Kaiming', name='lora_down', a=math.sqrt(5)),
+        ...         dict(type='Constant', name='lora_up', val=0)
+        ...     ])
     """
@@ -96,9 +128,10 @@ class LoRAModel(BaseModule):
                  rank: int = 0,
                  drop_rate: float = 0.,
                  targets: List[dict] = list(),
-                 init_cfg: Optional[dict] = None):
+                 init_cfg: Optional[Union[List[dict], dict]] = None):

         super().__init__(init_cfg)
+        self.init_cfg = init_cfg

         module = MODELS.build(module)
         module.init_weights()
@@ -156,7 +189,8 @@ class LoRAModel(BaseModule):
         parent_module = self.module.get_submodule(parent_module_name)

         target_name = module_name.split('.')[-1]
-        target_module = LoRALinear(current_module, alpha, rank, drop_rate)
+        target_module = LoRALinear(current_module, alpha, rank, drop_rate,
+                                   self.init_cfg)
         setattr(parent_module, target_name, target_module)

     def _set_lora_trainable(self):