From 19c177ca935b2d327ca58dab9f3e1f1ecc4f5af4 Mon Sep 17 00:00:00 2001
From: GinkREAL
Date: Sat, 25 Jan 2025 00:19:34 +0800
Subject: [PATCH] Fix prompt stacking in bedrock converse (#17613)

---
 .../llama_index/llms/bedrock_converse/base.py | 16 ++++------------
 .../pyproject.toml                            |  2 +-
 2 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
index 230585eac35a2..b64aeec18229f 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
@@ -302,15 +302,13 @@ def _get_content_and_tool_calls(
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
-        if len(system_prompt) > 0 or self.system_prompt is None:
-            self.system_prompt = system_prompt
         all_kwargs = self._get_all_kwargs(**kwargs)

         # invoke LLM in AWS Bedrock Converse with retry
         response = converse_with_retry(
             client=self._client,
             messages=converse_messages,
-            system_prompt=self.system_prompt,
+            system_prompt=system_prompt,
             max_retries=self.max_retries,
             stream=False,
             guardrail_identifier=self.guardrail_identifier,
@@ -349,15 +347,13 @@ def stream_chat(
     ) -> ChatResponseGen:
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
-        if len(system_prompt) > 0 or self.system_prompt is None:
-            self.system_prompt = system_prompt
         all_kwargs = self._get_all_kwargs(**kwargs)

         # invoke LLM in AWS Bedrock Converse with retry
         response = converse_with_retry(
             client=self._client,
             messages=converse_messages,
-            system_prompt=self.system_prompt,
+            system_prompt=system_prompt,
             max_retries=self.max_retries,
             stream=True,
             guardrail_identifier=self.guardrail_identifier,
@@ -431,8 +427,6 @@ async def achat(
     ) -> ChatResponse:
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
-        if len(system_prompt) > 0 or self.system_prompt is None:
-            self.system_prompt = system_prompt
         all_kwargs = self._get_all_kwargs(**kwargs)

         # invoke LLM in AWS Bedrock Converse with retry
@@ -440,7 +434,7 @@ async def achat(
             session=self._asession,
             config=self._config,
             messages=converse_messages,
-            system_prompt=self.system_prompt,
+            system_prompt=system_prompt,
             max_retries=self.max_retries,
             stream=False,
             guardrail_identifier=self.guardrail_identifier,
@@ -479,8 +473,6 @@ async def astream_chat(
     ) -> ChatResponseAsyncGen:
         # convert Llama Index messages to AWS Bedrock Converse messages
         converse_messages, system_prompt = messages_to_converse_messages(messages)
-        if len(system_prompt) > 0 or self.system_prompt is None:
-            self.system_prompt = system_prompt
         all_kwargs = self._get_all_kwargs(**kwargs)

         # invoke LLM in AWS Bedrock Converse with retry
@@ -488,7 +480,7 @@ async def astream_chat(
             session=self._asession,
             config=self._config,
             messages=converse_messages,
-            system_prompt=self.system_prompt,
+            system_prompt=system_prompt,
             max_retries=self.max_retries,
             stream=True,
             guardrail_identifier=self.guardrail_identifier,
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml
index 4b9d6b6f07bcf..0c0316856d9ba 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock-converse"
 readme = "README.md"
-version = "0.4.3"
+version = "0.4.4"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
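Note (not part of the patch): the fix removes the write-back of the per-call system prompt onto `self.system_prompt` in `chat`, `stream_chat`, `achat`, and `astream_chat`, so the prompt extracted from one call's messages no longer leaks into later calls ("prompt stacking"); each call now passes its own extracted `system_prompt` straight to `converse_with_retry` / `converse_with_retry_async`. A minimal sketch of the behavior change, assuming the public `BedrockConverse` API from this package; the model ID and message contents are illustrative, not taken from the patch:

    from llama_index.core.llms import ChatMessage
    from llama_index.llms.bedrock_converse import BedrockConverse

    llm = BedrockConverse(model="anthropic.claude-3-haiku-20240307-v1:0")

    # This call supplies a system message; messages_to_converse_messages()
    # splits it out as the system prompt for this request.
    llm.chat([
        ChatMessage(role="system", content="Answer like a pirate."),
        ChatMessage(role="user", content="Hello!"),
    ])

    # Before this patch, the call above overwrote llm.system_prompt, so the
    # pirate instruction silently carried over to this unrelated call. After
    # the patch, llm.system_prompt is left untouched and each request uses
    # only the system prompt found in its own `messages`.
    llm.chat([ChatMessage(role="user", content="Summarize this report.")])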