[Feature] Support LoRA. (#1687)

* [Feature] Support LoRA

* [Feature] Support LoRA

* [Fix] Fix bugs

* [Refactor] Add copyright

* [Fix] Fix bugs

* [Enhancement] Add

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Fix] Fix bugs

* [Docs] Update docstring

* [Docs] Update docstring

* [Refactor] Reformat with yapf

* [Docs] Update docstring

* [Refactor] Docformat

* [Refactor] Fix double-quote-string

* [Fix] fix pytorch version

* [Fix] isort

* [Fix] isort

* [Enhancement] Extend forward

* [Enhancement] Extend test

* [Fix] Fix targets

* [Enhancement] Extend LoRA to frozen models

* [Fix] Fix spelling

* [Fix] Override __getattr__

* [Fix] Add init_cfg

* [Enhancement] Add example config

* [Fix] Fix init_cfg

* [Enhancement] Add merging script

* [Fix] Remove init_cfg

* [Fix] Change lora key

* [Fix] Fix merge scripts

* [Fix] Fix merge scripts

* [Docs] Add docs

* [Fix] fix
fanqiNO1 2023-07-24 11:30:57 +08:00 committed by GitHub
parent 569324b180
commit 64c446d507
7 changed files with 520 additions and 0 deletions


@@ -0,0 +1,84 @@
_base_ = [
    '../_base_/datasets/imagenet_bs64_pil_resize.py',
    '../_base_/schedules/imagenet_bs4096_AdamW.py',
    '../_base_/default_runtime.py'
]

# model setting
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='LoRAModel',
        module=dict(
            type='VisionTransformer',
            arch='b',
            img_size=384,
            patch_size=16,
            drop_rate=0.1,
            init_cfg=dict(type='Pretrained', checkpoint='',
                          prefix='backbone')),
        alpha=16,
        rank=16,
        drop_rate=0.1,
        targets=[dict(type='qkv')]),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1,
            mode='classy_vision'),
        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)],
    ))

# dataset setting
data_preprocessor = dict(
    mean=[127.5, 127.5, 127.5],
    std=[127.5, 127.5, 127.5],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=384, backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=384, edge='short', backend='pillow'),
    dict(type='CenterCrop', crop_size=384),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=45,
        by_epoch=True,
        begin=5,
        end=50,
        eta_min=1e-6,
        convert_to_iter_based=True)
]

train_cfg = dict(by_epoch=True, max_epochs=50)

default_hooks = dict(
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))

# schedule setting
optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
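For reviewers who want to exercise the backbone half of this config in isolation, here is a minimal sketch (assuming an mmpretrain install and torch >= 1.9, which `get_submodule` requires; the `init_cfg` with its empty `checkpoint=''` is dropped so nothing is downloaded, whereas real training needs that path filled in):

```python
from mmpretrain.registry import MODELS

# Build only the LoRA-wrapped ViT-B backbone from the config above.
backbone = MODELS.build(
    dict(
        type='LoRAModel',
        module=dict(
            type='VisionTransformer',
            arch='b',
            img_size=384,
            patch_size=16,
            drop_rate=0.1),
        alpha=16,
        rank=16,
        drop_rate=0.1,
        targets=[dict(type='qkv')]))

# Only the injected LoRA parameters should remain trainable.
n_trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in backbone.parameters())
print(f'trainable params: {n_trainable} / {n_total}')
```

With rank 16 applied to the `qkv` projections only, the trainable count should be a small fraction of the full ViT-B parameter budget.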


@@ -15,6 +15,7 @@ The ``models`` package contains several sub-packages for addressing the differen
- :mod:`~mmpretrain.models.necks`: The component between backbones and heads, e.g., GlobalAveragePooling.
- :mod:`~mmpretrain.models.heads`: The component for specific tasks.
- :mod:`~mmpretrain.models.losses`: Loss functions.
- :mod:`~mmpretrain.models.peft`: The PEFT (Parameter-Efficient Fine-Tuning) module, e.g., LoRAModel.
- :mod:`~mmpretrain.models.utils`: Some helper functions and common components used in various networks.
- :mod:`~mmpretrain.models.utils.data_preprocessor`: The component before model to preprocess the inputs, e.g., ClsDataPreprocessor.

@@ -306,6 +307,17 @@ Losses
   SeesawLoss
   SwAVLoss

.. module:: mmpretrain.models.peft

PEFT
------------------

.. autosummary::
   :toctree: generated
   :nosignatures:

   LoRAModel

.. module:: mmpretrain.models.utils

models.utils
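As a quick sanity check of the API entry added above, the registered class is importable both from the new sub-package and from the top-level `mmpretrain.models` package (a sketch, assuming mmpretrain is installed with this commit):

```python
from mmpretrain.models import LoRAModel as LoRAModelReexport
from mmpretrain.models.peft import LoRAModel

# `from .peft import *` in mmpretrain/models/__init__.py re-exports the class.
assert LoRAModelReexport is LoRAModel
```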


@@ -8,6 +8,7 @@ from .heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .multimodal import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .peft import * # noqa: F401,F403
from .retrievers import * # noqa: F401,F403
from .selfsup import * # noqa: F401,F403
from .tta import * # noqa: F401,F403


@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lora import LoRAModel

__all__ = [
    'LoRAModel',
]


@@ -0,0 +1,205 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import re
from typing import Any, List

import torch
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


class LoRALinear(nn.Module):
    r"""Implements LoRA in a linear layer.

    Args:
        original_layer (nn.Linear): The linear layer to be finetuned.
        alpha (int): The scale factor of LoRA. Defaults to 1.
        rank (int): The rank of LoRA. Defaults to 0.
        drop_rate (float): The dropout rate for LoRA. Defaults to 0.

    Note:
        The forward process of the LoRA linear layer is:

        .. math::
            y = W_0 x + B A x \cdot (\alpha / r)

        where :math:`x` is the input, :math:`y` is the output,
        :math:`W_0` is the parameter of the original layer,
        :math:`A` and :math:`B` are the low-rank decomposition matrices,
        :math:`\alpha` is the scale factor and :math:`r` is the rank.
    """

    def __init__(self,
                 original_layer: nn.Linear,
                 alpha: int = 1,
                 rank: int = 0,
                 drop_rate: float = 0.):
        super(LoRALinear, self).__init__()
        in_features = original_layer.in_features
        out_features = original_layer.out_features

        self.lora_dropout = nn.Dropout(drop_rate)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        self.scaling = alpha / rank

        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        self.original_layer = original_layer

    def forward(self, x: torch.Tensor):
        out = self.original_layer(x)

        lora_x = self.lora_dropout(x)
        lora_out = self.lora_up(self.lora_down(lora_x)) * self.scaling

        return out + lora_out


@MODELS.register_module()
class LoRAModel(BaseModule):
    """Implements LoRA in a module.

    A PyTorch implementation of `LoRA: Low-Rank Adaptation of Large Language
    Models <https://arxiv.org/abs/2106.09685>`_.

    Args:
        module (dict): The config of the module to be finetuned. See
            :mod:`mmpretrain.models`.
        alpha (int): The scale factor of LoRA. Defaults to 1.
        rank (int): The rank of LoRA. Defaults to 0.
        drop_rate (float): The dropout rate for LoRA. Defaults to 0.
        targets (List[dict]): The target layers to which LoRA is applied,
            specified by a regular expression or a suffix. Defaults to an
            empty list.

    Examples:
        >>> model = LoRAModel(
        ...     module=dict(type='VisionTransformer', arch='b'),
        ...     alpha=4,
        ...     rank=4,
        ...     drop_rate=0.1,
        ...     targets=[
        ...         dict(type='.*qkv'),  # regular expression
        ...         dict(type='proj', alpha=8, rank=8),  # suffix
        ...     ])
    """

    def __init__(self,
                 module: dict,
                 alpha: int = 1,
                 rank: int = 0,
                 drop_rate: float = 0.,
                 targets: List[dict] = list()):
        super().__init__()

        module = MODELS.build(module)
        module.init_weights()

        self.module = module
        self.alpha = alpha
        self.rank = rank
        self.drop_rate = drop_rate

        assert len(targets) != 0, \
            'The length of target layers should not be 0.'
        self.targets = targets

        self.applied = False
        self.apply_lora()

        if not self.applied:
            raise ValueError(
                'No lora layer is replaced. Please check targets.')

        self._set_lora_trainable()
        self._register_state_dict_hooks()

    def apply_lora(self):
        """Apply LoRA to target layers."""
        module_names = [k for k, _ in self.module.named_modules()]
        for module_name in module_names:
            for target in self.targets:
                target_name = target['type']
                target_alpha = target.get('alpha', self.alpha)
                target_rank = target.get('rank', self.rank)
                target_drop_rate = target.get('drop_rate', self.drop_rate)

                if re.fullmatch(target_name, module_name) or \
                        module_name.endswith(target_name):
                    current_module = self.module.get_submodule(module_name)
                    if isinstance(current_module, nn.Linear):
                        print_log(
                            f'Set LoRA for {module_name} '
                            f'with alpha: {target_alpha}, '
                            f'rank: {target_rank}, '
                            f'drop rate: {target_drop_rate}',
                            logger='current')

                        self._replace_module(module_name, current_module,
                                             target_alpha, target_rank,
                                             target_drop_rate)
                        self.applied = True

    def _replace_module(self, module_name: str, current_module: nn.Module,
                        alpha: int, rank: int, drop_rate: float):
        """Replace target layer with LoRA linear layer in place."""
        parent_module_name = '.'.join(module_name.split('.')[:-1])
        parent_module = self.module.get_submodule(parent_module_name)

        target_name = module_name.split('.')[-1]
        target_module = LoRALinear(current_module, alpha, rank, drop_rate)
        setattr(parent_module, target_name, target_module)

    def _set_lora_trainable(self):
        """Set only the lora parameters trainable."""
        for name, param in self.named_parameters():
            if '.lora_' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def _register_state_dict_hooks(self):
        """Register state dict hooks.

        Register a state dict saving hook to save only the lora parameters to
        the state dict, and a state dict loading hook to handle the
        incompatible keys while loading the state dict.
        """

        def _state_dict_hook(module, state_dict, prefix, local_metadata):
            """Save only the lora parameters to the state dict."""
            keys = [k for k, _ in state_dict.items()]
            for key in keys:
                if '.lora_' not in key:
                    state_dict.pop(key)

        self._register_state_dict_hook(_state_dict_hook)

        def _load_state_dict_post_hook(module, incompatible_keys):
            """Handle the incompatible keys while loading the state dict."""
            missing_keys = incompatible_keys.missing_keys.copy()
            for key in missing_keys:
                if '.lora_' not in key:
                    incompatible_keys.missing_keys.remove(key)

            unexpected_keys = incompatible_keys.unexpected_keys.copy()
            for key in unexpected_keys:
                if '.lora_' not in key:
                    incompatible_keys.unexpected_keys.remove(key)

        self.register_load_state_dict_post_hook(_load_state_dict_post_hook)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        try:
            return super(LoRAModel, self).__getattr__(name)
        except AttributeError:
            return self.module.__getattribute__(name)
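Two properties of the implementation above are easy to check in a few lines: with `lora_up` initialised to zeros, a freshly wrapped layer is numerically identical to the original one, and `scaling` is simply `alpha / rank`. A standalone sketch (assuming the module above is importable as `mmpretrain.models.peft.lora`):

```python
import torch
from torch import nn

from mmpretrain.models.peft.lora import LoRALinear  # module added above

layer = nn.Linear(16, 16)
lora_layer = LoRALinear(layer, alpha=4, rank=2, drop_rate=0.)

# lora_up starts at zero, so the LoRA branch contributes nothing yet.
x = torch.randn(2, 16)
assert torch.allclose(lora_layer(x), layer(x))

# scaling = alpha / rank, independent of the layer size.
assert lora_layer.scaling == 2.0
```

Once wrapped in `LoRAModel`, the state dict hooks above additionally ensure that only keys containing `.lora_` are saved, which is what the tests below assert.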


@@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re

import pytest
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmpretrain.models.peft import LoRAModel


@pytest.mark.skipif(
    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
    reason='get_submodule requires torch >= 1.9.0')
def test_lora_backbone():
    module = dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        drop_path_rate=0.1,
        out_type='avg_featmap',
        final_norm=False)

    lora_cfg = dict(
        module=module,
        alpha=1,
        rank=4,
        drop_rate=0.1,
        targets=[
            dict(type='qkv'),
            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
        ])

    lora_model = LoRAModel(**lora_cfg)

    # test replace module
    for name, module in lora_model.named_modules():
        if name.endswith('qkv'):
            assert module.scaling == 0.25
        if re.fullmatch('.*proj', name):
            assert module.scaling == 1

    # test freeze module
    for name, param in lora_model.named_parameters():
        if 'lora_' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad

    # test get state dict
    state_dict = lora_model.state_dict()
    assert len(state_dict) != 0
    for name, param in state_dict.items():
        assert 'lora_' in name

    # test load state dict
    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
    assert str(incompatible_keys) == '<All keys matched successfully>'


@pytest.mark.skipif(
    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
    reason='get_submodule requires torch >= 1.9.0')
def test_lora_model():
    module = dict(
        type='MAE',
        backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
        neck=dict(
            type='MAEPretrainDecoder',
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4.,
        ),
        head=dict(
            type='MAEPretrainHead',
            norm_pix=True,
            patch_size=16,
            loss=dict(type='PixelReconstructionLoss', criterion='L2')),
        init_cfg=[
            dict(type='Xavier', layer='Linear', distribution='uniform'),
            dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
        ])

    lora_cfg = dict(
        module=module,
        alpha=1,
        rank=4,
        drop_rate=0.1,
        targets=[
            dict(type='qkv'),
            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
        ])

    lora_model = LoRAModel(**lora_cfg)

    # test replace module
    for name, module in lora_model.named_modules():
        if name.endswith('qkv'):
            assert module.scaling == 0.25
        if re.fullmatch('.*proj', name):
            assert module.scaling == 1

    # test freeze module
    for name, param in lora_model.named_parameters():
        if 'lora_' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad

    # test get state dict
    state_dict = lora_model.state_dict()
    assert len(state_dict) != 0
    for name, param in state_dict.items():
        assert 'lora_' in name

    # test load state dict
    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
    assert str(incompatible_keys) == '<All keys matched successfully>'


@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path

import torch
from mmengine.config import Config

from mmpretrain.registry import MODELS


@torch.no_grad()
def merge_lora_weight(cfg, lora_weight):
    """Merge the base weight and the LoRA weight.

    Args:
        cfg (dict): The config of the LoRAModel.
        lora_weight (dict): The state dict saved from the LoRAModel.

    Returns:
        dict: The merged weight.
    """
    temp = dict()
    mapping = dict()
    for name, param in lora_weight['state_dict'].items():
        # e.g. backbone.module.layers.11.attn.qkv.lora_down.weight
        if '.lora_' in name:
            lora_split = name.split('.')
            prefix = '.'.join(lora_split[:-2])
            if prefix not in mapping:
                mapping[prefix] = dict()
            lora_type = lora_split[-2]
            mapping[prefix][lora_type] = param
        else:
            temp[name] = param

    model = MODELS.build(cfg['model'])
    for name, param in model.named_parameters():
        if name in temp or '.lora_' in name:
            continue
        else:
            name_split = name.split('.')
            prefix = '.'.join(name_split[:-2])
            if prefix in mapping:
                name_split.pop(-2)
                if name_split[-1] == 'weight':
                    scaling = get_scaling(model, prefix)
                    lora_down = mapping[prefix]['lora_down']
                    lora_up = mapping[prefix]['lora_up']
                    param += lora_up @ lora_down * scaling
            name_split.pop(1)
            name = '.'.join(name_split)
            temp[name] = param

    result = dict()
    result['state_dict'] = temp
    result['meta'] = lora_weight['meta']
    return result


def get_scaling(model, prefix):
    """Get the scaling of the target layer.

    Args:
        model (LoRAModel): The LoRAModel.
        prefix (str): The prefix of the layer.

    Returns:
        float: The scaling of the LoRALinear.
    """
    prefix_split = prefix.split('.')
    for i in prefix_split:
        model = getattr(model, i)
    return model.scaling


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Merge LoRA weight')
    parser.add_argument('cfg', help='cfg path')
    parser.add_argument('src', help='src lora model path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The save path should be a ".pth" file name.')
        exit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)

    cfg = Config.fromfile(args.cfg)
    lora_model = torch.load(args.src, map_location='cpu')
    merged_model = merge_lora_weight(cfg, lora_model)
    torch.save(merged_model, args.dst)
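The merge rule used above, `W = W_0 + B @ A * (alpha / rank)`, can be checked on a single layer: folding the low-rank update into the base weight must give the same linear map as the original weight plus the LoRA branch. A standalone sketch in plain torch (hypothetical shapes, no mmpretrain needed):

```python
import torch

torch.manual_seed(0)
out_features, in_features, rank, alpha = 8, 8, 2, 4
scaling = alpha / rank

w0 = torch.randn(out_features, in_features)  # original weight W_0
lora_down = torch.randn(rank, in_features)   # A (lora_down.weight)
lora_up = torch.randn(out_features, rank)    # B (non-zero after training)

# What merge_lora_weight adds onto the base weight:
w_merged = w0 + lora_up @ lora_down * scaling

x = torch.randn(in_features)
y_unmerged = w0 @ x + (lora_up @ (lora_down @ x)) * scaling
y_merged = w_merged @ x
assert torch.allclose(y_unmerged, y_merged, atol=1e-6)
```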