Skip to content

Commit

Permalink
fix: Update other tests
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjcsmith committed Dec 23, 2024
1 parent 615ef87 commit e8593c5
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 39 deletions.
1 change: 1 addition & 0 deletions control-plane/src/modules/models/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const mockCreate = jest.fn(() => ({
}));

jest.mock("./routing", () => ({
...jest.requireActual("./routing"),
getRouting: jest.fn(() => ({
buildClient: jest.fn(() => ({
messages: {
Expand Down
4 changes: 2 additions & 2 deletions control-plane/src/modules/models/routing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import { BedrockCohereEmbeddings } from "../embeddings/bedrock-cohere-embeddings
import { CohereEmbeddings } from "@langchain/cohere";

export const CONTEXT_WINDOW: Record<string, number> = {
"claude-3-5-sonnet": 1000,
"claude-3-haiku": 1000,
"claude-3-5-sonnet": 200_000,
"claude-3-haiku": 200_000,
};

const routingOptions = {
Expand Down
30 changes: 12 additions & 18 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ const _handleModelCall = async (
findRelevantTools: ReleventToolLookup
): Promise<WorkflowStateUpdate> => {
detectCycle(state.messages);
const relevantSchemas = await findRelevantTools(state);

const relevantTools = await findRelevantTools(state);

addAttributes({
'model.relevant_tools': relevantSchemas.map(tool => tool.name),
'model.relevant_tools': relevantTools.map(tool => tool.name),
'model.available_tools': state.allAvailableTools,
'model.identifier': model.identifier,
});
Expand All @@ -52,40 +53,33 @@ const _handleModelCall = async (

const schema = buildModelSchema({
state,
relevantSchemas,
relevantSchemas: relevantTools,
resultSchema: state.workflow.resultSchema as JsonSchemaInput,
});

const schemaString = relevantSchemas.map(tool => {
return `${tool.name} - ${tool.description} ${tool.schema}`;
});

const systemPrompt = getSystemPrompt(state, schemaString);
const systemPrompt = getSystemPrompt(state, relevantTools);

const trimmedMessages = await handleContextWindowOverflow({
modelContextWindow: model.contextWindow ?? 0,
systemPrompt,
const truncatedMessages = await handleContextWindowOverflow({
messages: state.messages,
render: toAnthropicMessage,
systemPrompt: systemPrompt + JSON.stringify(schema),
modelContextWindow: model.contextWindow,
render: (m) => JSON.stringify(toAnthropicMessage(m)),
});

const renderedMessages = toAnthropicMessages(trimmedMessages);

if (state.workflow.debug) {
addAttributes({
'model.input.additional_context': state.additionalContext,
'model.input.systemPrompt': systemPrompt,
'model.input.messages': JSON.stringify(
state.messages.map(m => ({
truncatedMessages.map(m => ({
id: m.id,
type: m.type,
}))
),
'model.input.rendered_messages': JSON.stringify(renderedMessages),
});
}

const response = await model.structured({
messages: renderedMessages,
messages: toAnthropicMessages(truncatedMessages),
system: systemPrompt,
schema,
});
Expand Down
10 changes: 7 additions & 3 deletions control-plane/src/modules/workflows/agent/nodes/system-prompt.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { WorkflowAgentState } from "../state";
import { AgentTool } from "../tool";

export const getSystemPrompt = (
state: WorkflowAgentState,
schemaString: string[],
tools: AgentTool[],
): string => {
const basePrompt = [
"You are a helpful assistant with access to a set of tools designed to assist in completing tasks.",
Expand Down Expand Up @@ -36,16 +37,19 @@ export const getSystemPrompt = (
basePrompt.push(state.additionalContext);
}


// Add tool schemas
basePrompt.push("<TOOLS_SCHEMAS>");
basePrompt.push(...schemaString);
basePrompt.push(...tools.map(tool => {
return `${tool.name} - ${tool.description} ${tool.schema}`;
}));
basePrompt.push("</TOOLS_SCHEMAS>");

// Add other available tools
basePrompt.push("<OTHER_AVAILABLE_TOOLS>");
basePrompt.push(
...state.allAvailableTools.filter(
(t) => !schemaString.find((s) => s.includes(t)),
(t) => !tools.find((s) => s.name === t),
),
);
basePrompt.push("</OTHER_AVAILABLE_TOOLS>");
Expand Down
54 changes: 39 additions & 15 deletions control-plane/src/modules/workflows/agent/overflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,67 @@ const TOTAL_CONTEXT_THRESHOLD = 0.95;
const SYSTEM_PROMPT_THRESHOLD = 0.7;

export const handleContextWindowOverflow = async ({
systemPrompt,
messages,
systemPrompt,
modelContextWindow,
render = JSON.stringify
}: {
systemPrompt: string
messages: WorkflowAgentStateMessage[]
modelContextWindow: number
systemPrompt: string
modelContextWindow?: number
render? (message: WorkflowAgentStateMessage): unknown
//strategy?: "truncate"
}) => {
if (!modelContextWindow) {
logger.warn("Model context window not set, defaulting to 100_000");
modelContextWindow = 100_000;
}

const systemPromptTokenCount = await estimateTokenCount(systemPrompt);

if (systemPromptTokenCount > modelContextWindow * SYSTEM_PROMPT_THRESHOLD) {
throw new AgentError(`System prompt can not exceed ${modelContextWindow * SYSTEM_PROMPT_THRESHOLD} tokens`);
}

let messagesTokenCount = await estimateTokenCount(messages.map(render).join("\n"));
if (messagesTokenCount + systemPromptTokenCount < (modelContextWindow * TOTAL_CONTEXT_THRESHOLD)) {
return messages;
}
const inputTokenCount = await estimateTokenCount(messages.map(render).join("\n"));
let messagesTokenCount = inputTokenCount;

logger.info("Chat history exceeds context window, early messages will be dropped", {
systemPromptTokenCount,
messagesTokenCount,
})
const removedMessages: WorkflowAgentStateMessage[] = [];

do {
// Remove messages until total tokens are under threshold
while (messages.length && messagesTokenCount + systemPromptTokenCount > modelContextWindow * TOTAL_CONTEXT_THRESHOLD) {
if (messages.length === 1) {
throw new AgentError("Single chat message exceeds context window");
logger.error("A single message exceeds context window", {
messageId: messages[0].id
});
throw new AgentError("Run state is invalid");
}

messages.shift();
const removed = messages.shift()
removed && removedMessages.push(removed);

messagesTokenCount = await estimateTokenCount(messages.map(render).join("\n"));
}

// First message should always be human
while (messages.length && messages[0].type !== "human") {
if (messages.length === 1) {
logger.error("Only message in mesasge history is not human", {
messageId: messages[0].id
});
throw new AgentError("Run state is invalid");
}

} while (messagesTokenCount + systemPromptTokenCount > modelContextWindow * TOTAL_CONTEXT_THRESHOLD || messages[0].type !== 'human');
const removed = messages.shift()
removed && removedMessages.push(removed);
}

logger.info("Run history exceeds context window, early messages have been truncated", {
removedMessageIds: removedMessages.map(m => m.id),
systemPromptTokenCount,
outputTokenCount: messagesTokenCount,
inputTokenCount,
})

return messages;
};
2 changes: 1 addition & 1 deletion control-plane/src/modules/workflows/workflow-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ export const getRunMessagesForDisplay = async ({
displayableContext: workflowMessages.metadata,
})
.from(workflowMessages)
.orderBy(asc(workflowMessages.created_at))
.orderBy(desc(workflowMessages.created_at))
.where(
and(
eq(workflowMessages.cluster_id, clusterId),
Expand Down

0 comments on commit e8593c5

Please sign in to comment.