diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py index 4bbd5aaa..5e101648 100644 --- a/mmpretrain/models/multimodal/minigpt4/minigpt4.py +++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py @@ -155,8 +155,8 @@ class MiniGPT4(BaseModel): top_p=0.9, repetition_penalty=1.0, length_penalty=1.0, - temperature=1.0, - **generation_cfg) + temperature=1.0) + self.generation_cfg.update(**generation_cfg) if hasattr(self, 'register_load_state_dict_post_hook'): self.register_load_state_dict_post_hook(self._load_llama_proj_hook)