[Fix] Add init_cfg
parent
c5525a3049
commit
affb2406a6
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import re
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
from mmengine.logging import print_log
|
||||
|
@ -75,6 +75,8 @@ class LoRAModel(BaseModule):
|
|||
drop_rate (float): The drop out rate for LoRA. Defaults to 0.
|
||||
targets (List[dict]): The target layers to be applied with the LoRA.
|
||||
Defaults to a empty list.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> model = LoRAModel(
|
||||
|
@ -93,9 +95,10 @@ class LoRAModel(BaseModule):
|
|||
alpha: int = 1,
|
||||
rank: int = 0,
|
||||
drop_rate: float = 0.,
|
||||
targets: List[dict] = list()):
|
||||
targets: List[dict] = list(),
|
||||
init_cfg: Optional[dict] = None):
|
||||
|
||||
super().__init__()
|
||||
super().__init__(init_cfg)
|
||||
|
||||
module = MODELS.build(module)
|
||||
module.init_weights()
|
||||
|
|
Loading…
Reference in New Issue