# tests/test_models/test_peft/test_lora.py
# Copyright (c) OpenMMLab. All rights reserved.
import re

import pytest
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmpretrain.models.peft import LoRAModel


@pytest.mark.skipif(
    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
    reason='get_submodule requires torch >= 1.9.0')
def test_lora_backbone():
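    # Wrap a plain ViT backbone with LoRAModel and check target matching,
    # parameter freezing, and state-dict (de)serialization.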
    module = dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        drop_path_rate=0.1,
        out_type='avg_featmap',
        final_norm=False)

    lora_cfg = dict(
        module=module,
        alpha=1,
        rank=4,
        drop_rate=0.1,
        targets=[
            dict(type='qkv'),
            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
        ])

    lora_model = LoRAModel(**lora_cfg)

    # test replace module
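    # LoRA scales its update by alpha / rank: the 'qkv' target uses the
    # top-level defaults (1 / 4 = 0.25), while '.*proj' overrides them
    # (2 / 2 = 1).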
    for name, module in lora_model.named_modules():
        if name.endswith('qkv'):
            assert module.scaling == 0.25
        if re.fullmatch('.*proj', name):
            assert module.scaling == 1

    # test freeze module
    for name, param in lora_model.named_parameters():
        if 'lora_' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad

    # test get state dict
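    # only the injected LoRA parameters should be serialized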
    state_dict = lora_model.state_dict()
    assert len(state_dict) != 0
    for name, param in state_dict.items():
        assert 'lora_' in name

    # test load state dict
    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
    assert str(incompatible_keys) == '<All keys matched successfully>'


@pytest.mark.skipif(
    digit_version(TORCH_VERSION) < digit_version('1.9.0'),
    reason='get_submodule requires torch >= 1.9.0')
def test_lora_model():
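    # Same checks as test_lora_backbone, but wrapping a full self-supervised
    # MAE model so targets are matched across the whole module tree.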
    module = dict(
        type='MAE',
        backbone=dict(
            type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
        neck=dict(
            type='MAEPretrainDecoder',
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4.,
        ),
        head=dict(
            type='MAEPretrainHead',
            norm_pix=True,
            patch_size=16,
            loss=dict(type='PixelReconstructionLoss', criterion='L2')),
        init_cfg=[
            dict(type='Xavier', layer='Linear', distribution='uniform'),
            dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
        ])

    lora_cfg = dict(
        module=module,
        alpha=1,
        rank=4,
        drop_rate=0.1,
        targets=[
            dict(type='qkv'),
            dict(type='.*proj', alpha=2, rank=2, drop_rate=0.2),
        ])

    lora_model = LoRAModel(**lora_cfg)

    # test replace module
    for name, module in lora_model.named_modules():
        if name.endswith('qkv'):
            assert module.scaling == 0.25
        if re.fullmatch('.*proj', name):
            assert module.scaling == 1

    # test freeze module
    for name, param in lora_model.named_parameters():
        if 'lora_' in name:
            assert param.requires_grad
        else:
            assert not param.requires_grad

    # test get state dict
    state_dict = lora_model.state_dict()
    assert len(state_dict) != 0
    for name, param in state_dict.items():
        assert 'lora_' in name

    # test load state dict
    incompatible_keys = lora_model.load_state_dict(state_dict, strict=True)
    assert str(incompatible_keys) == '<All keys matched successfully>'