Python: Update sort step method for assistant invoke. (#10191)
### Motivation and Context

At times, when the (Azure) OpenAI Assistant Agent makes a tool call, that
tool call's creation timestamp comes after the message creation timestamp
(the message creation being the text the assistant responds with).
Currently in our code, if we have a tool call (`FunctionCallContent`), we
first yield that, then fetch the completed run steps and yield the
remaining content such as `FunctionResultContent` and `TextContent`. There
will be two steps (or more, depending on the number of tool calls).

Right now we sort the completed steps in this way:

```python
completed_steps_to_process: list[RunStep] = sorted(
    [s for s in steps if s.completed_at is not None and s.id not in processed_step_ids],
    key=lambda s: s.created_at,
)
```
When there are no failures, it's because the tool call was created before
the final message content (as has been the case since this assistant was
first coded). However, server-side processing can apparently cause
fluctuations in when the steps are created/processed. When we do hit a
failure, the message_creation step (`TextContent`) is yielded before the
`FunctionResultContent`; if that sequence is sent to an (Azure) OpenAI
Chat Completion endpoint, it breaks with a 400 because of the gap in the
ordering between the `FunctionCallContent` and the
`FunctionResultContent`:

```
FunctionCallContent
TextContent # this should follow `FunctionResultContent` (and it does during times when we don't see a 400)
FunctionResultContent
```
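
For illustration, here is a minimal sketch (using hypothetical stand-in step objects, not the real `RunStep`) of how sorting solely by `created_at` can produce that broken ordering when the server stamps the message_creation step with an earlier creation time than the tool_calls step:

```python
from dataclasses import dataclass


@dataclass
class FakeStep:
    """Hypothetical stand-in for a RunStep, for illustration only."""

    id: str
    type: str
    created_at: int
    completed_at: int


steps = [
    # The server occasionally stamps message_creation with an earlier
    # created_at than the tool_calls step that logically precedes it.
    FakeStep(id="step_msg", type="message_creation", created_at=100, completed_at=105),
    FakeStep(id="step_tool", type="tool_calls", created_at=101, completed_at=103),
]

# Sorting by created_at alone yields message_creation first, which later
# surfaces TextContent before FunctionResultContent and triggers the 400.
print([s.type for s in sorted(steps, key=lambda s: s.created_at)])
# ['message_creation', 'tool_calls']
```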

The 400 error isn't 100% repeatable because of server-side processing; at
other times we do get the correct ordering:

```
FunctionCallContent
FunctionResultContent
TextContent
```

This PR updates how we sort the completed steps: when we have steps of
type "tool_calls" and "message_creation", we sort so that "tool_calls"
comes first, and any ties are broken by the step's `completed_at`
timestamp. With this update, we no longer receive 400s because the
content types are in the correct order when sent to an OpenAI model.
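
Continuing the sketch above, the same illustrative data sorted with the new key (mirroring the `sort_key` helper added in the diff below) comes out in the expected order:

```python
def sort_key(step) -> tuple[int, int]:
    # Put tool_calls first, then message_creation;
    # ties within a type are broken by completed_at.
    return (0 if step.type == "tool_calls" else 1, step.completed_at)


# Reusing the FakeStep instances from the sketch above, the updated key
# restores the expected ordering regardless of created_at fluctuations.
print([s.type for s in sorted(steps, key=sort_key)])
# ['tool_calls', 'message_creation']
```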


### Description

This PR:
- Updates the way we sort completed run steps during an assistant invoke.
- Adds logging for the assistant invoke method at the `info` level (see the logging sketch after this list).
- Closes #10141
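
To surface the new log messages while debugging locally, something like the following should work (a sketch; it assumes the module-level loggers live under the `semantic_kernel` package, and DEBUG also captures the `info`-level messages):

```python
import logging

# Route semantic_kernel's module-level loggers to the console so the new
# invoke/step messages from open_ai_assistant_base.py become visible
# (assumed logger hierarchy; adjust the name if your setup differs).
logging.basicConfig(level=logging.WARNING)
logging.getLogger("semantic_kernel").setLevel(logging.DEBUG)
```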


### Contribution Checklist


- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄

---------

Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>
moonbox3 and eavanvalkenburg authored Jan 20, 2025
1 parent 5b65e6a commit 71e5040
Showing 3 changed files with 1,134 additions and 1,055 deletions.
49 changes: 46 additions & 3 deletions python/semantic_kernel/agents/open_ai/open_ai_assistant_base.py
@@ -730,6 +730,8 @@ async def _invoke_internal(
# Filter out None values to avoid passing them as kwargs
run_options = {k: v for k, v in run_options.items() if v is not None}

logger.debug(f"Starting invoke for agent `{self.name}` and thread `{thread_id}`")

run = await self.client.beta.threads.runs.create(
assistant_id=self.assistant.id,
thread_id=thread_id,
@@ -755,8 +757,13 @@

# Check if function calling required
if run.status == "requires_action":
logger.debug(f"Run [{run.id}] requires action for agent `{self.name}` and thread `{thread_id}`")
fccs = get_function_call_contents(run, function_steps)
if fccs:
logger.debug(
f"Yielding `generate_function_call_content` for agent `{self.name}` and "
f"thread `{thread_id}`, visibility False"
)
yield False, generate_function_call_content(agent_name=self.name, fccs=fccs)

from semantic_kernel.contents.chat_history import ChatHistory
@@ -770,28 +777,52 @@
thread_id=thread_id,
tool_outputs=tool_outputs, # type: ignore
)
logger.debug(f"Submitted tool outputs for agent `{self.name}` and thread `{thread_id}`")

steps_response = await self.client.beta.threads.runs.steps.list(run_id=run.id, thread_id=thread_id)
logger.debug(f"Called for steps_response for run [{run.id}] agent `{self.name}` and thread `{thread_id}`")
steps: list[RunStep] = steps_response.data
completed_steps_to_process: list[RunStep] = sorted(
[s for s in steps if s.completed_at is not None and s.id not in processed_step_ids],
key=lambda s: s.created_at,

def sort_key(step: RunStep):
# Put tool_calls first, then message_creation
# If multiple steps share a type, break ties by completed_at
return (0 if step.type == "tool_calls" else 1, step.completed_at)

completed_steps_to_process = sorted(
[s for s in steps if s.completed_at is not None and s.id not in processed_step_ids], key=sort_key
)

logger.debug(
f"Completed steps to process for run [{run.id}] agent `{self.name}` and thread `{thread_id}` "
f"with length `{len(completed_steps_to_process)}`"
)

message_count = 0
for completed_step in completed_steps_to_process:
if completed_step.type == "tool_calls":
logger.debug(
f"Entering step type tool_calls for run [{run.id}], agent `{self.name}` and "
f"thread `{thread_id}`"
)
assert hasattr(completed_step.step_details, "tool_calls") # nosec
for tool_call in completed_step.step_details.tool_calls:
is_visible = False
content: "ChatMessageContent | None" = None
if tool_call.type == "code_interpreter":
logger.debug(
f"Entering step type tool_calls for run [{run.id}], [code_interpreter] for "
f"agent `{self.name}` and thread `{thread_id}`"
)
content = generate_code_interpreter_content(
self.name,
tool_call.code_interpreter.input, # type: ignore
)
is_visible = True
elif tool_call.type == "function":
logger.debug(
f"Entering step type tool_calls for run [{run.id}], [function] for agent `{self.name}` "
f"and thread `{thread_id}`"
)
function_step = function_steps.get(tool_call.id)
assert function_step is not None # nosec
content = generate_function_result_content(
@@ -800,8 +831,16 @@

if content:
message_count += 1
logger.debug(
f"Yielding tool_message for run [{run.id}], agent `{self.name}` and thread "
f"`{thread_id}` and message count `{message_count}`, is_visible `{is_visible}`"
)
yield is_visible, content
elif completed_step.type == "message_creation":
logger.debug(
f"Entering step type message_creation for run [{run.id}], agent `{self.name}` and "
f"thread `{thread_id}`"
)
message = await self._retrieve_message(
thread_id=thread_id,
message_id=completed_step.step_details.message_creation.message_id, # type: ignore
@@ -810,6 +849,10 @@
content = generate_message_content(self.name, message)
if content and len(content.items) > 0:
message_count += 1
logger.debug(
f"Yielding message_creation for run [{run.id}], agent `{self.name}` and "
f"thread `{thread_id}` and message count `{message_count}`, is_visible `{True}`"
)
yield True, content
processed_step_ids.add(completed_step.id)

87 changes: 87 additions & 0 deletions python/tests/unit/agents/test_open_ai_assistant_base.py
@@ -408,6 +408,35 @@ def __init__(self):
)


@pytest.fixture
def mock_run_step_function_tool_call():
class MockToolCall:
def __init__(self):
self.type = "function"

return RunStep(
id="step_id_1",
type="tool_calls",
completed_at=int(datetime.now(timezone.utc).timestamp()),
created_at=int((datetime.now(timezone.utc) - timedelta(minutes=1)).timestamp()),
step_details=ToolCallsStepDetails(
tool_calls=[
FunctionToolCall(
type="function",
id="tool_call_id",
                function=RunsFunction(arguments="{}", name="function_name", output="test output"),
),
],
type="tool_calls",
),
assistant_id="assistant_id",
object="thread.run.step",
run_id="run_id",
status="completed",
thread_id="thread_id",
)


@pytest.fixture
def mock_run_step_message_creation():
class MockMessageCreation:
@@ -1206,6 +1235,64 @@ def mock_get_function_call_contents(run, function_steps):
_ = [message async for message in azure_openai_assistant_agent.invoke("thread_id")]


async def test_invoke_order(
azure_openai_assistant_agent,
mock_assistant,
mock_run_required_action,
mock_run_step_function_tool_call,
mock_run_step_message_creation,
mock_thread_messages,
mock_function_call_content,
):
poll_count = 0

async def mock_poll_run_status(run, thread_id):
nonlocal poll_count
if run.status == "requires_action":
if poll_count == 0:
pass
else:
run.status = "completed"
poll_count += 1
return run

def mock_get_function_call_contents(run, function_steps):
function_call_content = mock_function_call_content
function_call_content.id = "tool_call_id"
function_steps[function_call_content.id] = function_call_content
return [function_call_content]

azure_openai_assistant_agent.assistant = mock_assistant
azure_openai_assistant_agent._poll_run_status = AsyncMock(side_effect=mock_poll_run_status)
azure_openai_assistant_agent._retrieve_message = AsyncMock(return_value=mock_thread_messages[0])

with patch(
"semantic_kernel.agents.open_ai.assistant_content_generation.get_function_call_contents",
side_effect=mock_get_function_call_contents,
):
client = azure_openai_assistant_agent.client

with patch.object(client.beta.threads.runs, "create", new_callable=AsyncMock) as mock_runs_create:
mock_runs_create.return_value = mock_run_required_action

with (
patch.object(client.beta.threads.runs, "submit_tool_outputs", new_callable=AsyncMock),
patch.object(client.beta.threads.runs.steps, "list", new_callable=AsyncMock) as mock_steps_list,
):
mock_steps_list.return_value = MagicMock(
data=[mock_run_step_message_creation, mock_run_step_function_tool_call]
)

messages = []
async for _, content in azure_openai_assistant_agent._invoke_internal("thread_id"):
messages.append(content)

assert len(messages) == 3
assert isinstance(messages[0].items[0], FunctionCallContent)
assert isinstance(messages[1].items[0], FunctionResultContent)
assert isinstance(messages[2].items[0], TextContent)


async def test_invoke_stream(
azure_openai_assistant_agent,
mock_assistant,
