diff --git a/configs/vision_transformer/vit-base-p16_8xb64-lora_in1k-384px.py b/configs/vision_transformer/vit-base-p16_8xb64-lora_in1k-384px.py
new file mode 100644
index 00000000..ffe1018e
--- /dev/null
+++ b/configs/vision_transformer/vit-base-p16_8xb64-lora_in1k-384px.py
@@ -0,0 +1,84 @@
+_base_ = [
+    '../_base_/datasets/imagenet_bs64_pil_resize.py',
+    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+# model setting
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='LoRAModel',
+        module=dict(
+            type='VisionTransformer',
+            arch='b',
+            img_size=384,
+            patch_size=16,
+            drop_rate=0.1,
+            init_cfg=dict(type='Pretrained', checkpoint='',
+                          prefix='backbone')),
+        alpha=16,
+        rank=16,
+        drop_rate=0.1,
+        targets=[dict(type='qkv')]),
+    neck=None,
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1,
+            mode='classy_vision'),
+        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)],
+    ))
+
+# dataset setting
+data_preprocessor = dict(
+    mean=[127.5, 127.5, 127.5],
+    std=[127.5, 127.5, 127.5],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='RandomResizedCrop', scale=384, backend='pillow'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='ResizeEdge', scale=384, edge='short', backend='pillow'),
+    dict(type='CenterCrop', crop_size=384),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=5,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=45,
+        by_epoch=True,
+        begin=5,
+        end=50,
+        eta_min=1e-6,
+        convert_to_iter_based=True)
+]
+
+train_cfg = dict(by_epoch=True, max_epochs=50)
+default_hooks = dict(
+    # save checkpoint per epoch.
+    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
+
+# schedule setting
+optim_wrapper = dict(clip_grad=dict(max_norm=1.0))
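Note on this config: the 'Pretrained' checkpoint above is left empty ('') and needs to point at a ViT-B/16 weight before actual fine-tuning. As a quick structural check, the wrapped model can be built the same way the merge tool at the end of this patch does; a minimal sketch, run from the repository root (the init_cfg override exists only because LoRAModel calls module.init_weights() at build time and the checkpoint path is still empty):

from mmengine.config import Config

from mmpretrain.registry import MODELS

cfg = Config.fromfile(
    'configs/vision_transformer/vit-base-p16_8xb64-lora_in1k-384px.py')
# Drop the still-empty 'Pretrained' init_cfg for this smoke test only.
cfg.model.backbone.module.init_cfg = None

model = MODELS.build(cfg.model)

# Only the injected LoRA branches (plus the head, which sits outside
# LoRAModel) remain trainable.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f'trainable: {n_trainable} / {n_total}')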
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 93e3e841..30980324 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -15,6 +15,7 @@ The ``models`` package contains several sub-packages for addressing the differen
 - :mod:`~mmpretrain.models.necks`: The component between backbones and heads, e.g., GlobalAveragePooling.
 - :mod:`~mmpretrain.models.heads`: The component for specific tasks.
 - :mod:`~mmpretrain.models.losses`: Loss functions.
+- :mod:`~mmpretrain.models.peft`: The PEFT (Parameter-Efficient Fine-Tuning) module, e.g., LoRAModel.
 - :mod:`~mmpretrain.models.utils`: Some helper functions and common components used in various networks.
 - :mod:`~mmpretrain.models.utils.data_preprocessor`: The component before model to preprocess the inputs, e.g., ClsDataPreprocessor.
@@ -306,6 +307,17 @@ Losses
    SeesawLoss
    SwAVLoss
 
+.. module:: mmpretrain.models.peft
+
+PEFT
+------------------
+
+.. autosummary::
+   :toctree: generated
+   :nosignatures:
+
+   LoRAModel
+
 .. module:: mmpretrain.models.utils
 
 models.utils
diff --git a/mmpretrain/models/__init__.py b/mmpretrain/models/__init__.py
index ba05735b..3f583114 100644
--- a/mmpretrain/models/__init__.py
+++ b/mmpretrain/models/__init__.py
@@ -8,6 +8,7 @@ from .heads import *  # noqa: F401,F403
 from .losses import *  # noqa: F401,F403
 from .multimodal import *  # noqa: F401,F403
 from .necks import *  # noqa: F401,F403
+from .peft import *  # noqa: F401,F403
 from .retrievers import *  # noqa: F401,F403
 from .selfsup import *  # noqa: F401,F403
 from .tta import *  # noqa: F401,F403
diff --git a/mmpretrain/models/peft/__init__.py b/mmpretrain/models/peft/__init__.py
new file mode 100644
index 00000000..9f43e148
--- /dev/null
+++ b/mmpretrain/models/peft/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .lora import LoRAModel
+
+__all__ = [
+    'LoRAModel',
+]
diff --git a/mmpretrain/models/peft/lora.py b/mmpretrain/models/peft/lora.py
new file mode 100644
index 00000000..ae1bae7f
--- /dev/null
+++ b/mmpretrain/models/peft/lora.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import re
+from typing import Any, List
+
+import torch
+from mmengine.logging import print_log
+from mmengine.model import BaseModule
+from torch import nn
+
+from mmpretrain.registry import MODELS
+
+
+class LoRALinear(nn.Module):
+    r"""Implements LoRA in a linear layer.
+
+    Args:
+        original_layer (nn.Linear): The linear layer to be finetuned.
+        alpha (int): The scale factor of LoRA. Defaults to 1.
+        rank (int): The rank of LoRA. Defaults to 0.
+        drop_rate (float): The dropout rate for LoRA. Defaults to 0.
+
+    Note:
+        The forward process of the LoRA linear layer is:
+
+        .. math::
+            y = W_0 x + BAx \cdot (\alpha / r)
+
+        where :math:`x` is the input, :math:`y` is the output,
+        :math:`W_0` is the parameter of the original layer,
+        :math:`A` and :math:`B` are the low-rank decomposition matrices,
+        :math:`\alpha` is the scale factor and :math:`r` is the rank.
+    """
+
+    def __init__(self,
+                 original_layer: nn.Linear,
+                 alpha: int = 1,
+                 rank: int = 0,
+                 drop_rate: float = 0.):
+        super(LoRALinear, self).__init__()
+        in_features = original_layer.in_features
+        out_features = original_layer.out_features
+
+        self.lora_dropout = nn.Dropout(drop_rate)
+        self.lora_down = nn.Linear(in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, out_features, bias=False)
+        self.scaling = alpha / rank
+
+        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_up.weight)
+
+        self.original_layer = original_layer
+
+    def forward(self, x: torch.Tensor):
+        out = self.original_layer(x)
+
+        lora_x = self.lora_dropout(x)
+        lora_out = self.lora_up(self.lora_down(lora_x)) * self.scaling
+
+        return out + lora_out
+
+
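For reference, a minimal standalone sketch of the parameter savings behind the formula above, using an assumed 768-dim ViT-B sized layer (LoRALinear is not exported in __all__, so it is imported from the submodule directly):

import torch
from torch import nn

from mmpretrain.models.peft.lora import LoRALinear

base = nn.Linear(768, 768)
lora = LoRALinear(base, alpha=16, rank=16, drop_rate=0.1)

# lora_up is zero-initialized, so at this point the wrapped layer still
# computes exactly the same output as the original one.
x = torch.randn(2, 197, 768)
assert lora(x).shape == (2, 197, 768)

# The two rank-16 projections add 2 * 768 * 16 = 24576 parameters, compared
# with 768 * 768 + 768 = 590592 in the wrapped layer (frozen later by
# LoRAModel._set_lora_trainable).
n_lora = sum(p.numel() for n, p in lora.named_parameters() if 'lora_' in n)
print(n_lora)  # 24576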
+@MODELS.register_module()
+class LoRAModel(BaseModule):
+    """Implements LoRA in a module.
+
+    A PyTorch implementation of `LoRA: Low-Rank Adaptation
+    of Large Language Models <https://arxiv.org/abs/2106.09685>`_.
+
+    Args:
+        module (dict): The config of the module to be finetuned. See
+            :mod:`mmpretrain.models`.
+        alpha (int): The scale factor of LoRA. Defaults to 1.
+        rank (int): The rank of LoRA. Defaults to 0.
+        drop_rate (float): The dropout rate for LoRA. Defaults to 0.
+        targets (List[dict]): The target layers to be applied with the LoRA.
+            Defaults to an empty list. Each target is specified by a regular
+            expression or a suffix of the layer name.
+
+    Examples:
+        >>> model = LoRAModel(
+        ...     module=dict(type='VisionTransformer', arch='b'),
+        ...     alpha=4,
+        ...     rank=4,
+        ...     drop_rate=0.1,
+        ...     targets=[
+        ...         dict(type='.*qkv'),  # regular expression
+        ...         dict(type='proj', alpha=8, rank=8),  # suffix
+        ...     ])
+    """
+
+    def __init__(self,
+                 module: dict,
+                 alpha: int = 1,
+                 rank: int = 0,
+                 drop_rate: float = 0.,
+                 targets: List[dict] = list()):
+
+        super().__init__()
+
+        module = MODELS.build(module)
+        module.init_weights()
+
+        self.module = module
+        self.alpha = alpha
+        self.rank = rank
+        self.drop_rate = drop_rate
+
+        assert len(targets) != 0, \
+            'The length of target layers should not be 0.'
+
+        self.targets = targets
+
+        self.applied = False
+        self.apply_lora()
+
+        if not self.applied:
+            raise ValueError(
+                'No lora layer is replaced. Please check targets.')
+
+        self._set_lora_trainable()
+        self._register_state_dict_hooks()
+
+    def apply_lora(self):
+        """Apply LoRA to the target layers."""
+        module_names = [k for k, _ in self.module.named_modules()]
+        for module_name in module_names:
+            for target in self.targets:
+                target_name = target['type']
+                target_alpha = target.get('alpha', self.alpha)
+                target_rank = target.get('rank', self.rank)
+                target_drop_rate = target.get('drop_rate', self.drop_rate)
+
+                if re.fullmatch(target_name, module_name) or \
+                        module_name.endswith(target_name):
+                    current_module = self.module.get_submodule(module_name)
+                    if isinstance(current_module, nn.Linear):
+                        print_log(
+                            f'Set LoRA for {module_name} '
+                            f'with alpha: {target_alpha}, '
+                            f'rank: {target_rank}, '
+                            f'drop rate: {target_drop_rate}',
+                            logger='current')
+
+                        self._replace_module(module_name, current_module,
+                                             target_alpha, target_rank,
+                                             target_drop_rate)
+                        self.applied = True
+
+    def _replace_module(self, module_name: str, current_module: nn.Module,
+                        alpha: int, rank: int, drop_rate: float):
+        """Replace the target layer with a LoRA linear layer in place."""
+        parent_module_name = '.'.join(module_name.split('.')[:-1])
+        parent_module = self.module.get_submodule(parent_module_name)
+
+        target_name = module_name.split('.')[-1]
+        target_module = LoRALinear(current_module, alpha, rank, drop_rate)
+        setattr(parent_module, target_name, target_module)
+
+    def _set_lora_trainable(self):
+        """Set only the LoRA parameters trainable."""
+        for name, param in self.named_parameters():
+            if '.lora_' in name:
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
+
+    def _register_state_dict_hooks(self):
+        """Register state dict hooks.
+
+        Register a state dict saving hook to save only the LoRA parameters
+        to the state dict, and a state dict loading hook to handle the
+        incompatible keys while loading the state dict.
+        """
+
+        def _state_dict_hook(module, state_dict, prefix, local_metadata):
+            """Save only the LoRA parameters to the state dict."""
+            keys = [k for k, _ in state_dict.items()]
+            for key in keys:
+                if '.lora_' not in key:
+                    state_dict.pop(key)
+
+        self._register_state_dict_hook(_state_dict_hook)
+
+        def _load_state_dict_post_hook(module, incompatible_keys):
+            """Handle the incompatible keys while loading the state dict."""
+            missing_keys = incompatible_keys.missing_keys.copy()
+            for key in missing_keys:
+                if '.lora_' not in key:
+                    incompatible_keys.missing_keys.remove(key)
+
+            unexpected_keys = incompatible_keys.unexpected_keys.copy()
+            for key in unexpected_keys:
+                if '.lora_' not in key:
+                    incompatible_keys.unexpected_keys.remove(key)
+
+        self.register_load_state_dict_post_hook(_load_state_dict_post_hook)
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return super(LoRAModel, self).__getattr__(name)
+        except AttributeError:
+            return self.module.__getattribute__(name)
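Putting the pieces above together, a short sketch of the intended behaviour (suffix-based target matching, LoRA-only checkpoints); it assumes PyTorch >= 1.9 for get_submodule, and the save path is only illustrative:

import torch

from mmpretrain.models.peft import LoRAModel

model = LoRAModel(
    module=dict(type='VisionTransformer', arch='b'),
    alpha=16,
    rank=16,
    drop_rate=0.1,
    targets=[dict(type='qkv')])  # suffix match on every '...attn.qkv'

# Only LoRA parameters are trainable, and only they reach the state dict.
assert all('.lora_' in key for key in model.state_dict())
assert all(param.requires_grad == ('.lora_' in name)
           for name, param in model.named_parameters())

# Checkpoints therefore stay small; the file name below is hypothetical.
torch.save(model.state_dict(), 'vit-b_lora_only.pth')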
diff --git a/tests/test_models/test_peft/test_lora.py b/tests/test_models/test_peft/test_lora.py
new file mode 100644
index 00000000..d1485381
--- /dev/null
+++ b/tests/test_models/test_peft/test_lora.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+
+import pytest
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+from mmpretrain.models.peft import LoRAModel
+
+
+@pytest.mark.skipif(
+    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
+    reason='get_submodule requires torch >= 1.9.0')
+def test_lora_backbone():
+    module = dict(
+        type='VisionTransformer',
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.1,
+        out_type='avg_featmap',
+        final_norm=False)
+
+    lora_cfg = dict(
+        module=module,
+        alpha=1,
+        rank=4,
+        drop_rate=0.1,
+        targets=[
+            dict(type='qkv'),
+            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
+        ])
+
+    lora_model = LoRAModel(**lora_cfg)
+
+    # test replace module
+    for name, module in lora_model.named_modules():
+        if name.endswith('qkv'):
+            assert module.scaling == 0.25
+        if re.fullmatch('.*proj', name):
+            assert module.scaling == 1
+
+    # test freeze module
+    for name, param in lora_model.named_parameters():
+        if 'lora_' in name:
+            assert param.requires_grad
+        else:
+            assert not param.requires_grad
+
+    # test get state dict
+    state_dict = lora_model.state_dict()
+    assert len(state_dict) != 0
+    for name, param in state_dict.items():
+        assert 'lora_' in name
+
+    # test load state dict
+    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
+    assert str(incompatible_keys) == ''
+
+
+@pytest.mark.skipif(
+    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
+    reason='get_submodule requires torch >= 1.9.0')
+def test_lora_model():
+    module = dict(
+        type='MAE',
+        backbone=dict(
+            type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
+        neck=dict(
+            type='MAEPretrainDecoder',
+            patch_size=16,
+            in_chans=3,
+            embed_dim=768,
+            decoder_embed_dim=512,
+            decoder_depth=8,
+            decoder_num_heads=16,
+            mlp_ratio=4.,
+        ),
+        head=dict(
+            type='MAEPretrainHead',
+            norm_pix=True,
+            patch_size=16,
+            loss=dict(type='PixelReconstructionLoss', criterion='L2')),
+        init_cfg=[
+            dict(type='Xavier', layer='Linear', distribution='uniform'),
+            dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
+        ])
+
+    lora_cfg = dict(
+        module=module,
+        alpha=1,
+        rank=4,
+        drop_rate=0.1,
+        targets=[
+            dict(type='qkv'),
+            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
+        ])
+
+    lora_model = LoRAModel(**lora_cfg)
+
+    # test replace module
+    for name, module in lora_model.named_modules():
+        if name.endswith('qkv'):
+            assert module.scaling == 0.25
+        if re.fullmatch('.*proj', name):
+            assert module.scaling == 1
+
+    # test freeze module
+    for name, param in lora_model.named_parameters():
+        if 'lora_' in name:
+            assert param.requires_grad
+        else:
+            assert not param.requires_grad
+
+    # test get state dict
+    state_dict = lora_model.state_dict()
+    assert len(state_dict) != 0
+    for name, param in state_dict.items():
+        assert 'lora_' in name
+
+    # test load state dict
+    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
+    assert str(incompatible_keys) == ''
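A further case that could sit alongside the tests above, sketching the ValueError that LoRAModel.__init__ raises when no target layer matches:

import pytest

from mmpretrain.models.peft import LoRAModel


def test_lora_no_target_matched():
    module = dict(type='VisionTransformer', arch='base')
    # No ViT submodule matches this target, so apply_lora replaces nothing
    # and the constructor refuses to continue.
    with pytest.raises(ValueError, match='No lora layer is replaced'):
        LoRAModel(
            module=module,
            alpha=1,
            rank=4,
            targets=[dict(type='this_suffix_matches_nothing')])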
diff --git a/tools/model_converters/merge_lora_weight.py b/tools/model_converters/merge_lora_weight.py
new file mode 100644
index 00000000..fc51f9f2
--- /dev/null
+++ b/tools/model_converters/merge_lora_weight.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+from pathlib import Path
+
+import torch
+from mmengine.config import Config
+
+from mmpretrain.registry import MODELS
+
+
+@torch.no_grad()
+def merge_lora_weight(cfg, lora_weight):
+    """Merge the base weight and the LoRA weight.
+
+    Args:
+        cfg (dict): The config of the LoRAModel.
+        lora_weight (dict): The checkpoint dict saved from the LoRAModel.
+
+    Returns:
+        dict: The merged checkpoint.
+    """
+    temp = dict()
+    mapping = dict()
+    for name, param in lora_weight['state_dict'].items():
+        # e.g. backbone.module.layers.11.attn.qkv.lora_down.weight
+        if '.lora_' in name:
+            lora_split = name.split('.')
+            prefix = '.'.join(lora_split[:-2])
+            if prefix not in mapping:
+                mapping[prefix] = dict()
+            lora_type = lora_split[-2]
+            mapping[prefix][lora_type] = param
+        else:
+            temp[name] = param
+
+    model = MODELS.build(cfg['model'])
+    for name, param in model.named_parameters():
+        if name in temp or '.lora_' in name:
+            continue
+        else:
+            name_split = name.split('.')
+            prefix = '.'.join(name_split[:-2])
+            if prefix in mapping:
+                # drop the 'original_layer' level added by LoRALinear
+                name_split.pop(-2)
+                if name_split[-1] == 'weight':
+                    scaling = get_scaling(model, prefix)
+                    lora_down = mapping[prefix]['lora_down']
+                    lora_up = mapping[prefix]['lora_up']
+                    param += lora_up @ lora_down * scaling
+            # drop the 'module' level added by LoRAModel
+            name_split.pop(1)
+            name = '.'.join(name_split)
+            temp[name] = param
+
+    result = dict()
+    result['state_dict'] = temp
+    result['meta'] = lora_weight['meta']
+    return result
+
+
+def get_scaling(model, prefix):
+    """Get the scaling of the target layer.
+
+    Args:
+        model (LoRAModel): The LoRAModel.
+        prefix (str): The prefix of the layer.
+
+    Returns:
+        float: The scaling of the LoRALinear.
+    """
+    prefix_split = prefix.split('.')
+    for i in prefix_split:
+        model = getattr(model, i)
+    return model.scaling
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Merge LoRA weight')
+    parser.add_argument('cfg', help='cfg path')
+    parser.add_argument('src', help='src lora model path')
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+    dst = Path(args.dst)
+    if dst.suffix != '.pth':
+        print('The save path should point to a .pth file.')
+        exit(1)
+    dst.parent.mkdir(parents=True, exist_ok=True)
+
+    cfg = Config.fromfile(args.cfg)
+    lora_model = torch.load(args.src, map_location='cpu')
+
+    merged_model = merge_lora_weight(cfg, lora_model)
+    torch.save(merged_model, args.dst)
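Typical usage of the converter, with illustrative paths; since the name handling above folds the LoRA branches into the original weights and strips the LoRAModel naming, the merged checkpoint should fit a plain, un-wrapped ViT-B/16 classifier:

# python tools/model_converters/merge_lora_weight.py \
#     configs/vision_transformer/vit-base-p16_8xb64-lora_in1k-384px.py \
#     work_dirs/lora/epoch_50.pth work_dirs/lora/merged.pth
import torch

ckpt = torch.load('work_dirs/lora/merged.pth', map_location='cpu')
state_dict = ckpt['state_dict']

# merge_lora_weight() adds lora_up @ lora_down * (alpha / rank) to each
# wrapped weight, so no LoRA or wrapper keys should remain afterwards.
assert not any('.lora_' in k or '.module.' in k for k in state_dict)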