diff --git a/mmpretrain/models/peft/lora.py b/mmpretrain/models/peft/lora.py
index 5ace6772..ae1bae7f 100644
--- a/mmpretrain/models/peft/lora.py
+++ b/mmpretrain/models/peft/lora.py
@@ -176,7 +176,7 @@ class LoRAModel(BaseModule):
             """Save only the lora parameters to the state dict."""
             keys = [k for k, _ in state_dict.items()]
             for key in keys:
-                if 'lora_' not in key:
+                if '.lora_' not in key:
                     state_dict.pop(key)
 
         self._register_state_dict_hook(_state_dict_hook)
@@ -185,12 +185,12 @@ class LoRAModel(BaseModule):
             """Handle the incompatible keys while loading the state dict."""
             missing_keys = incompatible_keys.missing_keys.copy()
             for key in missing_keys:
-                if 'lora_' not in key:
+                if '.lora_' not in key:
                     incompatible_keys.missing_keys.remove(key)
 
             unexpected_keys = incompatible_keys.unexpected_keys.copy()
             for key in unexpected_keys:
-                if 'lora_' not in key:
+                if '.lora_' not in key:
                     incompatible_keys.unexpected_keys.remove(key)
 
         self.register_load_state_dict_post_hook(_load_state_dict_post_hook)