Skip to content

Commit

Permalink
refactor(ai): Improve tool metadata handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Jan 7, 2025
1 parent d1cf816 commit 3981a66
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,31 @@

import org.springframework.ai.model.function.FunctionCallback;

import io.arconia.ai.core.tools.metadata.ToolMetadata;

/**
* Specialization of {@link FunctionCallback} to identify tools in Spring AI.
*/
public interface ToolCallback extends FunctionCallback {}
public interface ToolCallback extends FunctionCallback {

/**
* Metadata for the tool.
*/
ToolMetadata getToolMetadata();

@Override
default String getName() {
return getToolMetadata().name();
}

@Override
default String getDescription() {
return getToolMetadata().description();
}

@Override
default String getInputTypeSchema() {
return getToolMetadata().inputTypeSchema();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,11 @@ public MethodToolCallback(ToolMetadata toolMetadata, Method toolMethod, @Nullabl
this.toolObject = toolObject;
}

@Override
public ToolMetadata getToolMetadata() {
return toolMetadata;
}

@Override
public String getName() {
return toolMetadata.name();
}

@Override
public String getDescription() {
return toolMetadata.description();
}

@Override
public String getInputTypeSchema() {
return toolMetadata.inputTypeSchema();
}

@Override
public String call(String toolInput) {
return call(toolInput, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import org.junit.jupiter.api.Test;

import io.arconia.ai.core.tools.metadata.DefaultToolMetadata;
import io.arconia.ai.core.tools.metadata.ToolMetadata;
import io.arconia.ai.core.tools.util.ToolUtils;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -38,37 +40,19 @@ void shouldNotDetectDuplicateToolNames() {

static class TestToolCallback implements ToolCallback {

private final String name;

private final String description;

private final String inputTypeSchema;
private final ToolMetadata toolMetadata;

public TestToolCallback(String name) {
this.name = name;
this.description = "";
this.inputTypeSchema = "";
}

public TestToolCallback(String name, String description, String inputTypeSchema) {
this.name = name;
this.description = description;
this.inputTypeSchema = inputTypeSchema;
}

@Override
public String getName() {
return name;
}

@Override
public String getDescription() {
return description;
this.toolMetadata = DefaultToolMetadata.builder()
.name(name)
.description(name)
.inputTypeSchema("{}")
.build();
}

@Override
public String getInputTypeSchema() {
return inputTypeSchema;
public ToolMetadata getToolMetadata() {
return toolMetadata;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
package io.arconia.ai.mcp.tools;

import java.util.Map;

import com.fasterxml.jackson.core.type.TypeReference;

import org.springframework.ai.mcp.client.McpSyncClient;
import org.springframework.ai.mcp.spec.McpSchema;
import org.springframework.ai.mcp.spring.McpFunctionCallback;
import org.springframework.ai.model.ModelOptionsUtils;

import io.arconia.ai.core.tools.ToolCallback;
import io.arconia.ai.core.tools.json.JsonParser;
import io.arconia.ai.core.tools.metadata.DefaultToolMetadata;
import io.arconia.ai.core.tools.metadata.ToolMetadata;

/**
* A {@link ToolCallback} for handling calls to MCP tools.
*/
public class McpToolCallback extends McpFunctionCallback implements ToolCallback {
public class McpToolCallback implements ToolCallback {

private final ToolMetadata toolMetadata;
private final McpSyncClient mcpClient;

public McpToolCallback(McpSchema.Tool tool, McpSyncClient mcpClient) {
this.toolMetadata = DefaultToolMetadata.builder()
.name(tool.name())
.description(tool.description())
.inputTypeSchema(JsonParser.toJson(tool.inputSchema()))
.build();
this.mcpClient = mcpClient;
}

@Override
public ToolMetadata getToolMetadata() {
return this.toolMetadata;
}

public McpToolCallback(McpSyncClient mcpClient, McpSchema.Tool tool) {
super(mcpClient, tool);
@Override
public String call(String toolInput) {
Map<String, Object> arguments = JsonParser.fromJson(toolInput, new TypeReference<>() {});
McpSchema.CallToolResult response = this.mcpClient.callTool(new McpSchema.CallToolRequest(this.getName(), arguments));
return ModelOptionsUtils.toJsonString(response.content());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public ToolCallback[] getToolCallbacks() {
.flatMap(mcpClient -> mcpClient.listTools()
.tools()
.stream()
.map(tool -> (ToolCallback) new McpToolCallback(mcpClient, tool)))
.map(tool -> new McpToolCallback(tool, mcpClient)))
.toArray(ToolCallback[]::new);

validateToolCallbacks(toolCallbacks);
Expand Down

0 comments on commit 3981a66

Please sign in to comment.