[Feature] Support LoRA. (#1687)
* [Feature] Support LoRA
* [Feature] Support LoRA
* [Fix] Fix bugs
* [Refactor] Add copyright
* [Fix] Fix bugs
* [Enhancement] Add
* [Fix] Fix bugs
* [Fix] Fix bugs
* [Fix] Fix bugs
* [Fix] Fix bugs
* [Fix] Fix bugs
* [Docs] Update docstring
* [Docs] Update docstring
* [Refactor] Reformat with yapf
* [Docs] Update docstring
* [Refactor] Docformat
* [Refactor] Fix double-quote-string
* [Fix] fix pytorch version
* [Fix] isort
* [Fix] isort
* [Enhancement] Extend forward
* [Enhancement] Extend test
* [Fix] Fix targets
* [Enhancement] Extend LoRA to frozen models
* [Fix] Fix spelling
* [Fix] Override __getattr__
* [Fix] Add init_cfg
* [Enhancement] Add example config
* [Fix] Fix init_cfg
* [Enhancement] Add merging script
* [Fix] Remove init_cfg
* [Fix] Change lora key
* [Fix] Fix merge scripts
* [Fix] Fix merge scripts
* [Docs] Add docs
* [Fix] fix
parent 569324b180, commit 64c446d507
@@ -0,0 +1,84 @@
_base_ = [
    '../_base_/datasets/imagenet_bs64_pil_resize.py',
    '../_base_/schedules/imagenet_bs4096_AdamW.py',
    '../_base_/default_runtime.py'
]

# model setting
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='LoRAModel',
        module=dict(
            type='VisionTransformer',
            arch='b',
            img_size=384,
            patch_size=16,
            drop_rate=0.1,
            init_cfg=dict(type='Pretrained', checkpoint='',
                          prefix='backbone')),
        alpha=16,
        rank=16,
        drop_rate=0.1,
        targets=[dict(type='qkv')]),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1,
            mode='classy_vision'),
        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)],
    ))

# dataset setting
data_preprocessor = dict(
    mean=[127.5, 127.5, 127.5],
    std=[127.5, 127.5, 127.5],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=384, backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=384, edge='short', backend='pillow'),
    dict(type='CenterCrop', crop_size=384),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=45,
        by_epoch=True,
        begin=5,
        end=50,
        eta_min=1e-6,
        convert_to_iter_based=True)
]

train_cfg = dict(by_epoch=True, max_epochs=50)
default_hooks = dict(
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))

# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
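For reference, a minimal sketch (not part of the diff) of how the ``backbone`` section of this config is resolved at build time: the registry constructs a LoRAModel that wraps the ViT, replaces every ``qkv`` linear layer with a LoRALinear, and freezes everything else. The ``Pretrained`` init_cfg is omitted below because its checkpoint path is left empty in the config; the print at the end is illustrative only.

# Sketch: build only the backbone part of the config above and count
# trainable parameters. Assumes mmpretrain with this PR is installed.
from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(
        type='LoRAModel',
        module=dict(type='VisionTransformer', arch='b', img_size=384,
                    patch_size=16, drop_rate=0.1),
        alpha=16,
        rank=16,
        drop_rate=0.1,
        targets=[dict(type='qkv')]))

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f'trainable: {trainable} / total: {total}')  # only the LoRA matrices train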
@@ -15,6 +15,7 @@ The ``models`` package contains several sub-packages for addressing the differen
- :mod:`~mmpretrain.models.necks`: The component between backbones and heads, e.g., GlobalAveragePooling.
- :mod:`~mmpretrain.models.heads`: The component for specific tasks.
- :mod:`~mmpretrain.models.losses`: Loss functions.
- :mod:`~mmpretrain.models.peft`: The PEFT (Parameter-Efficient Fine-Tuning) module, e.g. LoRAModel.
- :mod:`~mmpretrain.models.utils`: Some helper functions and common components used in various networks.

  - :mod:`~mmpretrain.models.utils.data_preprocessor`: The component before model to preprocess the inputs, e.g., ClsDataPreprocessor.
@@ -306,6 +307,17 @@ Losses
   SeesawLoss
   SwAVLoss

.. module:: mmpretrain.models.peft

PEFT
------------------

.. autosummary::
   :toctree: generated
   :nosignatures:

   LoRAModel

.. module:: mmpretrain.models.utils

models.utils
@@ -8,6 +8,7 @@ from .heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .multimodal import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .peft import *  # noqa: F401,F403
from .retrievers import *  # noqa: F401,F403
from .selfsup import *  # noqa: F401,F403
from .tta import *  # noqa: F401,F403
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lora import LoRAModel

__all__ = [
    'LoRAModel',
]
@@ -0,0 +1,205 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import re
from typing import Any, List

import torch
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


class LoRALinear(nn.Module):
    r"""Implements LoRA in a linear layer.

    Args:
        original_layer (nn.Linear): The linear layer to be finetuned.
        alpha (int): The scale factor of LoRA. Defaults to 1.
        rank (int): The rank of LoRA. Defaults to 0.
        drop_rate (float): The drop out rate for LoRA. Defaults to 0.

    Note:
        The forward process of the LoRA linear layer is:

        .. math::
            y = W_0 x + BAx \cdot (\alpha / r)

        where :math:`x` is the input, :math:`y` is the output,
        :math:`W_0` is the parameter of the original layer,
        :math:`A` and :math:`B` are the low-rank decomposition matrices,
        :math:`\alpha` is the scale factor and :math:`r` is the rank.
    """

    def __init__(self,
                 original_layer: nn.Linear,
                 alpha: int = 1,
                 rank: int = 0,
                 drop_rate: float = 0.):
        super(LoRALinear, self).__init__()
        in_features = original_layer.in_features
        out_features = original_layer.out_features

        self.lora_dropout = nn.Dropout(drop_rate)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        self.scaling = alpha / rank

        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        self.original_layer = original_layer

    def forward(self, x: torch.Tensor):
        out = self.original_layer(x)

        lora_x = self.lora_dropout(x)
        lora_out = self.lora_up(self.lora_down(lora_x)) * self.scaling

        return out + lora_out


@MODELS.register_module()
class LoRAModel(BaseModule):
    """Implements LoRA in a module.

    A PyTorch implementation of `LoRA: Low-Rank Adaptation
    of Large Language Models <https://arxiv.org/abs/2106.09685>`_.

    Args:
        module (dict): The config of the module to be finetuned. See
            :mod:`mmpretrain.models`.
        alpha (int): The scale factor of LoRA. Defaults to 1.
        rank (int): The rank of LoRA. Defaults to 0.
        drop_rate (float): The drop out rate for LoRA. Defaults to 0.
        targets (List[dict]): The target layers to be applied with the LoRA.
            Defaults to an empty list. Specify by regular expression or
            suffix.

    Examples:
        >>> model = LoRAModel(
        ...     module=dict(type='VisionTransformer', arch='b'),
        ...     alpha=4,
        ...     rank=4,
        ...     drop_rate=0.1,
        ...     targets=[
        ...         dict(type='.*qkv'),  # regular expression
        ...         dict(type='proj', alpha=8, rank=8),  # suffix
        ...     ])
    """

    def __init__(self,
                 module: dict,
                 alpha: int = 1,
                 rank: int = 0,
                 drop_rate: float = 0.,
                 targets: List[dict] = list()):

        super().__init__()

        module = MODELS.build(module)
        module.init_weights()

        self.module = module
        self.alpha = alpha
        self.rank = rank
        self.drop_rate = drop_rate

        assert len(targets) != 0, \
            'The length of target layers should not be 0.'

        self.targets = targets

        self.applied = False
        self.apply_lora()

        if not self.applied:
            raise ValueError(
                'No lora layer is replaced. Please check targets.')

        self._set_lora_trainable()
        self._register_state_dict_hooks()

    def apply_lora(self):
        """Apply LoRA to target layers."""
        module_names = [k for k, _ in self.module.named_modules()]
        for module_name in module_names:
            for target in self.targets:
                target_name = target['type']
                target_alpha = target.get('alpha', self.alpha)
                target_rank = target.get('rank', self.rank)
                target_drop_rate = target.get('drop_rate', self.drop_rate)

                if re.fullmatch(target_name, module_name) or \
                        module_name.endswith(target_name):
                    current_module = self.module.get_submodule(module_name)
                    if isinstance(current_module, nn.Linear):
                        print_log(
                            f'Set LoRA for {module_name} '
                            f'with alpha: {target_alpha}, '
                            f'rank: {target_rank}, '
                            f'drop rate: {target_drop_rate}',
                            logger='current')

                        self._replace_module(module_name, current_module,
                                             target_alpha, target_rank,
                                             target_drop_rate)
                        self.applied = True

    def _replace_module(self, module_name: str, current_module: nn.Module,
                        alpha: int, rank: int, drop_rate: float):
        """Replace target layer with LoRA linear layer in place."""
        parent_module_name = '.'.join(module_name.split('.')[:-1])
        parent_module = self.module.get_submodule(parent_module_name)

        target_name = module_name.split('.')[-1]
        target_module = LoRALinear(current_module, alpha, rank, drop_rate)
        setattr(parent_module, target_name, target_module)

    def _set_lora_trainable(self):
        """Set only the lora parameters trainable."""
        for name, param in self.named_parameters():
            if '.lora_' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def _register_state_dict_hooks(self):
        """Register state dict hooks.

        Register state dict saving hooks to save only the lora parameters to
        the state dict. And register state dict loading hooks to handle the
        incompatible keys while loading the state dict.
        """

        def _state_dict_hook(module, state_dict, prefix, local_metadata):
            """Save only the lora parameters to the state dict."""
            keys = [k for k, _ in state_dict.items()]
            for key in keys:
                if '.lora_' not in key:
                    state_dict.pop(key)

        self._register_state_dict_hook(_state_dict_hook)

        def _load_state_dict_post_hook(module, incompatible_keys):
            """Handle the incompatible keys while loading the state dict."""
            missing_keys = incompatible_keys.missing_keys.copy()
            for key in missing_keys:
                if '.lora_' not in key:
                    incompatible_keys.missing_keys.remove(key)

            unexpected_keys = incompatible_keys.unexpected_keys.copy()
            for key in unexpected_keys:
                if '.lora_' not in key:
                    incompatible_keys.unexpected_keys.remove(key)

        self.register_load_state_dict_post_hook(_load_state_dict_post_hook)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        try:
            return super(LoRAModel, self).__getattr__(name)
        except AttributeError:
            return self.module.__getattribute__(name)
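A minimal sketch (not part of the diff) of what the wrapping above does to a single layer. It uses only the LoRALinear class added in this file; the tensor shapes are illustrative.

# Sketch: wrap one nn.Linear with the LoRALinear defined above.
# With alpha=4 and rank=4 the scaling is alpha / rank = 1.0; after
# LoRAModel._set_lora_trainable(), only lora_down and lora_up would train.
import torch
from torch import nn

from mmpretrain.models.peft.lora import LoRALinear  # added by this PR

layer = nn.Linear(768, 768)
lora_layer = LoRALinear(layer, alpha=4, rank=4, drop_rate=0.1)

x = torch.randn(2, 197, 768)
assert lora_layer(x).shape == (2, 197, 768)

# lora_up is zero-initialized, so the wrapped layer initially reproduces
# the original layer's output exactly (dropout disabled in eval mode).
lora_layer.eval()
assert torch.allclose(lora_layer(x), layer(x))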
@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re

import pytest
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmpretrain.models.peft import LoRAModel


@pytest.mark.skipif(
    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
    reason='get_submodule requires torch >= 1.9.0')
def test_lora_backbone():
    module = dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        drop_path_rate=0.1,
        out_type='avg_featmap',
        final_norm=False)

    lora_cfg = dict(
        module=module,
        alpha=1,
        rank=4,
        drop_rate=0.1,
        targets=[
            dict(type='qkv'),
            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
        ])

    lora_model = LoRAModel(**lora_cfg)

    # test replace module
    for name, module in lora_model.named_modules():
        if name.endswith('qkv'):
            assert module.scaling == 0.25
        if re.fullmatch('.*proj', name):
            assert module.scaling == 1

    # test freeze module
    for name, param in lora_model.named_parameters():
        if 'lora_' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad

    # test get state dict
    state_dict = lora_model.state_dict()
    assert len(state_dict) != 0
    for name, param in state_dict.items():
        assert 'lora_' in name

    # test load state dict
    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
    assert str(incompatible_keys) == '<All keys matched successfully>'


@pytest.mark.skipif(
    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
    reason='get_submodule requires torch >= 1.9.0')
def test_lora_model():
    module = dict(
        type='MAE',
        backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
        neck=dict(
            type='MAEPretrainDecoder',
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4.,
        ),
        head=dict(
            type='MAEPretrainHead',
            norm_pix=True,
            patch_size=16,
            loss=dict(type='PixelReconstructionLoss', criterion='L2')),
        init_cfg=[
            dict(type='Xavier', layer='Linear', distribution='uniform'),
            dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
        ])

    lora_cfg = dict(
        module=module,
        alpha=1,
        rank=4,
        drop_rate=0.1,
        targets=[
            dict(type='qkv'),
            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
        ])

    lora_model = LoRAModel(**lora_cfg)

    # test replace module
    for name, module in lora_model.named_modules():
        if name.endswith('qkv'):
            assert module.scaling == 0.25
        if re.fullmatch('.*proj', name):
            assert module.scaling == 1

    # test freeze module
    for name, param in lora_model.named_parameters():
        if 'lora_' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad

    # test get state dict
    state_dict = lora_model.state_dict()
    assert len(state_dict) != 0
    for name, param in state_dict.items():
        assert 'lora_' in name

    # test load state dict
    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
    assert str(incompatible_keys) == '<All keys matched successfully>'
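The expected values in the ``module.scaling`` assertions above follow directly from ``scaling = alpha / rank`` in LoRALinear. A quick check of the arithmetic (illustrative only, not part of the test file):

# scaling = alpha / rank, per LoRALinear.__init__
qkv_scaling = 1 / 4   # dict(type='qkv') inherits the top-level alpha=1, rank=4
proj_scaling = 2 / 2  # dict(type='.*proj', alpha=2, rank=2) overrides both
assert qkv_scaling == 0.25 and proj_scaling == 1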
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch
from mmengine.config import Config

from mmpretrain.registry import MODELS


@torch.no_grad()
def merge_lora_weight(cfg, lora_weight):
    """Merge base weight and lora weight.

    Args:
        cfg (dict): config for LoRAModel.
        lora_weight (dict): weight dict from LoRAModel.

    Returns:
        Merged weight.
    """
    temp = dict()
    mapping = dict()
    for name, param in lora_weight['state_dict'].items():
        # e.g. backbone.module.layers.11.attn.qkv.lora_down.weight
        if '.lora_' in name:
            lora_split = name.split('.')
            prefix = '.'.join(lora_split[:-2])
            if prefix not in mapping:
                mapping[prefix] = dict()
            lora_type = lora_split[-2]
            mapping[prefix][lora_type] = param
        else:
            temp[name] = param

    model = MODELS.build(cfg['model'])
    for name, param in model.named_parameters():
        if name in temp or '.lora_' in name:
            continue
        else:
            name_split = name.split('.')
            prefix = '.'.join(name_split[:-2])
            if prefix in mapping:
                name_split.pop(-2)
                if name_split[-1] == 'weight':
                    scaling = get_scaling(model, prefix)
                    lora_down = mapping[prefix]['lora_down']
                    lora_up = mapping[prefix]['lora_up']
                    param += lora_up @ lora_down * scaling
            name_split.pop(1)
            name = '.'.join(name_split)
            temp[name] = param

    result = dict()
    result['state_dict'] = temp
    result['meta'] = lora_weight['meta']
    return result


def get_scaling(model, prefix):
    """Get the scaling of target layer.

    Args:
        model (LoRAModel): the LoRAModel.
        prefix (str): the prefix of the layer.

    Returns:
        the scale of the LoRALinear.
    """
    prefix_split = prefix.split('.')
    for i in prefix_split:
        model = getattr(model, i)
    return model.scaling


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Merge LoRA weight')
    parser.add_argument('cfg', help='cfg path')
    parser.add_argument('src', help='src lora model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)

    cfg = Config.fromfile(args.cfg)
    lora_model = torch.load(args.src, map_location='cpu')

    merged_model = merge_lora_weight(cfg, lora_model)
    torch.save(merged_model, args.dst)
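The diff does not show where this script lives in the repository, so the invocation below is a hedged sketch that assumes merge_lora_weight from the script above is in scope (e.g. run in the same file); the config and checkpoint paths are placeholders, not values from the PR.

# Sketch: merge a LoRA checkpoint programmatically, mirroring the
# script's __main__ block.
import torch
from mmengine.config import Config

cfg = Config.fromfile('path/to/lora_config.py')  # the LoRAModel training config
lora_ckpt = torch.load('path/to/lora_epoch_50.pth', map_location='cpu')

merged = merge_lora_weight(cfg, lora_ckpt)  # defined in the script above
torch.save(merged, 'path/to/merged.pth')    # full weights with LoRA deltas folded in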