From 3981a66aff458d97478aa822c3070c053570fbcf Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Tue, 7 Jan 2025 01:01:25 +0100 Subject: [PATCH] refactor(ai): Improve tool metadata handling --- .../arconia/ai/core/tools/ToolCallback.java | 26 +++++++++++++- .../core/tools/method/MethodToolCallback.java | 16 +-------- .../arconia/ai/core/tools/ToolUtilsTests.java | 36 ++++++------------- .../arconia/ai/mcp/tools/McpToolCallback.java | 35 +++++++++++++++--- .../ai/mcp/tools/McpToolCallbackProvider.java | 2 +- 5 files changed, 68 insertions(+), 47 deletions(-) diff --git a/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallback.java b/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallback.java index 2b2d4e3..704f6e0 100644 --- a/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallback.java +++ b/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/ToolCallback.java @@ -2,7 +2,31 @@ import org.springframework.ai.model.function.FunctionCallback; +import io.arconia.ai.core.tools.metadata.ToolMetadata; + /** * Specialization of {@link FunctionCallback} to identify tools in Spring AI. */ -public interface ToolCallback extends FunctionCallback {} +public interface ToolCallback extends FunctionCallback { + + /** + * Metadata for the tool. + */ + ToolMetadata getToolMetadata(); + + @Override + default String getName() { + return getToolMetadata().name(); + } + + @Override + default String getDescription() { + return getToolMetadata().description(); + } + + @Override + default String getInputTypeSchema() { + return getToolMetadata().inputTypeSchema(); + } + +} diff --git a/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java b/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java index cee3c2e..5acee9a 100644 --- a/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java +++ b/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java @@ -41,25 +41,11 @@ public MethodToolCallback(ToolMetadata toolMetadata, Method toolMethod, @Nullabl this.toolObject = toolObject; } + @Override public ToolMetadata getToolMetadata() { return toolMetadata; } - @Override - public String getName() { - return toolMetadata.name(); - } - - @Override - public String getDescription() { - return toolMetadata.description(); - } - - @Override - public String getInputTypeSchema() { - return toolMetadata.inputTypeSchema(); - } - @Override public String call(String toolInput) { return call(toolInput, null); diff --git a/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java b/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java index 0dc7e53..32006dc 100644 --- a/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java +++ b/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java @@ -4,6 +4,8 @@ import org.junit.jupiter.api.Test; +import io.arconia.ai.core.tools.metadata.DefaultToolMetadata; +import io.arconia.ai.core.tools.metadata.ToolMetadata; import io.arconia.ai.core.tools.util.ToolUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -38,37 +40,19 @@ void shouldNotDetectDuplicateToolNames() { static class TestToolCallback implements ToolCallback { - private final String name; - - private final String description; - - private final String inputTypeSchema; + private final ToolMetadata toolMetadata; public TestToolCallback(String name) { - this.name = name; - this.description = ""; - this.inputTypeSchema = ""; - } - - public TestToolCallback(String name, String description, String inputTypeSchema) { - this.name = name; - this.description = description; - this.inputTypeSchema = inputTypeSchema; - } - - @Override - public String getName() { - return name; - } - - @Override - public String getDescription() { - return description; + this.toolMetadata = DefaultToolMetadata.builder() + .name(name) + .description(name) + .inputTypeSchema("{}") + .build(); } @Override - public String getInputTypeSchema() { - return inputTypeSchema; + public ToolMetadata getToolMetadata() { + return toolMetadata; } @Override diff --git a/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallback.java b/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallback.java index 26e866e..b8753f3 100644 --- a/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallback.java +++ b/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallback.java @@ -1,18 +1,45 @@ package io.arconia.ai.mcp.tools; +import java.util.Map; + +import com.fasterxml.jackson.core.type.TypeReference; + import org.springframework.ai.mcp.client.McpSyncClient; import org.springframework.ai.mcp.spec.McpSchema; -import org.springframework.ai.mcp.spring.McpFunctionCallback; +import org.springframework.ai.model.ModelOptionsUtils; import io.arconia.ai.core.tools.ToolCallback; +import io.arconia.ai.core.tools.json.JsonParser; +import io.arconia.ai.core.tools.metadata.DefaultToolMetadata; +import io.arconia.ai.core.tools.metadata.ToolMetadata; /** * A {@link ToolCallback} for handling calls to MCP tools. */ -public class McpToolCallback extends McpFunctionCallback implements ToolCallback { +public class McpToolCallback implements ToolCallback { + + private final ToolMetadata toolMetadata; + private final McpSyncClient mcpClient; + + public McpToolCallback(McpSchema.Tool tool, McpSyncClient mcpClient) { + this.toolMetadata = DefaultToolMetadata.builder() + .name(tool.name()) + .description(tool.description()) + .inputTypeSchema(JsonParser.toJson(tool.inputSchema())) + .build(); + this.mcpClient = mcpClient; + } + + @Override + public ToolMetadata getToolMetadata() { + return this.toolMetadata; + } - public McpToolCallback(McpSyncClient mcpClient, McpSchema.Tool tool) { - super(mcpClient, tool); + @Override + public String call(String toolInput) { + Map arguments = JsonParser.fromJson(toolInput, new TypeReference<>() {}); + McpSchema.CallToolResult response = this.mcpClient.callTool(new McpSchema.CallToolRequest(this.getName(), arguments)); + return ModelOptionsUtils.toJsonString(response.content()); } } diff --git a/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallbackProvider.java b/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallbackProvider.java index c22ed54..870abcb 100644 --- a/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallbackProvider.java +++ b/arconia-ai/arconia-ai-mcp/src/main/java/io/arconia/ai/mcp/tools/McpToolCallbackProvider.java @@ -32,7 +32,7 @@ public ToolCallback[] getToolCallbacks() { .flatMap(mcpClient -> mcpClient.listTools() .tools() .stream() - .map(tool -> (ToolCallback) new McpToolCallback(mcpClient, tool))) + .map(tool -> new McpToolCallback(tool, mcpClient))) .toArray(ToolCallback[]::new); validateToolCallbacks(toolCallbacks);