Skip to content

Commit

Permalink
feat(ai): Support tools via Tool annotation and MCP
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Dec 31, 2024
1 parent 1741bec commit b051bf3
Show file tree
Hide file tree
Showing 14 changed files with 352 additions and 2 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Arconia

Arconia is a framework to build SaaS, multitenant applications using Java and Spring Boot.
Arconia is a framework to build modern applications using Java and Spring Boot.
It provides support for multitenancy and AI.

<img src="arconia-logo.png" alt="The Arconia logo" height="250px" />

Expand Down
30 changes: 30 additions & 0 deletions arconia-ai/arconia-ai-core/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
plugins {
id 'code-quality-conventions'
id 'java-conventions'
id 'sbom-conventions'
id 'release-conventions'
}

repositories {
maven { url 'https://repo.spring.io/milestone' }
}

dependencies {
implementation "org.slf4j:slf4j-api"
implementation "org.springframework:spring-context"
compileOnly "org.springframework.ai:spring-ai-core:${springAiVersion}"

testImplementation "org.springframework.boot:spring-boot-starter-test"
testImplementation "org.springframework.ai:spring-ai-core:${springAiVersion}"
}

publishing {
publications {
mavenJava(MavenPublication) {
pom {
name = "Arconia AI Core"
description = "Arconia AI Core."
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.arconia.ai.core;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.arconia.ai.core.tools;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.springframework.ai.model.function.FunctionCallback;

/**
* Annotation to mark a method as a tool.
*/
@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Tool {

/**
* The description of the tool. If not provided, the method name will be used.
*/
String value() default "";

/**
* The name of the tool. If not provided, the method name will be used.
*/
String name() default "";

/**
* The schema type of the tool. JSON Schema will work for most cases. Vertex AI
* requires OpenAPI Schema.
*/
FunctionCallback.SchemaType schemaType() default FunctionCallback.SchemaType.JSON_SCHEMA;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.arconia.ai.core.tools;

import org.springframework.ai.model.function.FunctionCallback;

/**
* Wrapper for {@link FunctionCallback} to identify tools.
*/
public interface ToolCallback extends FunctionCallback {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.arconia.ai.core.tools;

import org.springframework.ai.model.function.FunctionCallback;

/**
* Resolves {@link ToolCallback} instances from different sources.
*/
public interface ToolCallbackResolver {

FunctionCallback[] getToolCallbacks();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.arconia.ai.core.tools.method;

import java.util.stream.Stream;

import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

import io.arconia.ai.core.tools.Tool;
import io.arconia.ai.core.tools.ToolCallback;
import io.arconia.ai.core.tools.ToolCallbackResolver;

/**
* A {@link ToolCallbackResolver} that resolves {@link ToolCallback} instances from
* methods annotated with {@link Tool}.
*/
public class MethodToolCallbackResolver implements ToolCallbackResolver {

private final Object target;

private MethodToolCallbackResolver(Object target) {
Assert.notNull(target, "target cannot be null");
this.target = target;
}

@Override
public FunctionCallback[] getToolCallbacks() {
return Stream.of(ReflectionUtils.getDeclaredMethods(target.getClass()))
.filter(method -> method.isAnnotationPresent(Tool.class))
.map(method -> FunctionCallback.builder()
.method(method.getName(), method.getParameterTypes())
.name(getToolName(method.getAnnotation(Tool.class), method.getName()))
.description(getToolDescription(method.getAnnotation(Tool.class), method.getName()))
.schemaType(method.getAnnotation(Tool.class).schemaType())
.targetObject(target)
.build())
.toArray(FunctionCallback[]::new);
}

private static String getToolName(Tool tool, String methodName) {
return StringUtils.hasText(tool.name()) ? tool.name() : methodName;
}

private static String getToolDescription(Tool tool, String methodName) {
return StringUtils.hasText(tool.value()) ? tool.value() : methodName;
}

public static class Builder {

private Object target;

public Builder target(Object target) {
this.target = target;
return this;
}

public MethodToolCallbackResolver build() {
return new MethodToolCallbackResolver(target);
}

}

public static Builder builder() {
return new Builder();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.arconia.ai.core.tools.method;

import io.arconia.ai.core.tools.Tool;

import org.junit.jupiter.api.Test;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.MethodInvokingFunctionCallback;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Unit tests for {@link MethodToolCallbackResolver}.
*/
class MethodToolCallbackResolverTests {

@Test
void shouldResolveToolCallbacks() {
TestComponent testComponent = new TestComponent();
MethodToolCallbackResolver resolver = MethodToolCallbackResolver.builder().target(testComponent).build();

FunctionCallback[] callbacks = resolver.getToolCallbacks();

assertThat(callbacks).hasSize(1);
MethodInvokingFunctionCallback callback = (MethodInvokingFunctionCallback) callbacks[0];
assertThat(callback.getName()).isEqualTo("testMethod");
assertThat(callback.getDescription()).isEqualTo("Test description");
}

@Test
void shouldFailWhenTargetIsNotProvided() {
assertThatThrownBy(() -> MethodToolCallbackResolver.builder().build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("target cannot be null");
}

static class TestComponent {

@Tool("Test description")
public List<String> testMethod(String input) {
return List.of(input);
}

public void nonToolMethod() {
// This method should be ignored as it doesn't have @Tool annotation
}

}

}
35 changes: 35 additions & 0 deletions arconia-ai/arconia-ai-mcp/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
plugins {
id 'code-quality-conventions'
id 'java-conventions'
id 'sbom-conventions'
id 'release-conventions'
}

repositories {
maven { url 'https://repo.spring.io/milestone' }
}

dependencies {
implementation project(":arconia-ai:arconia-ai-core")

implementation "org.slf4j:slf4j-api"
implementation "org.springframework:spring-context"

compileOnly "org.springframework.ai:spring-ai-core:${springAiVersion}"
compileOnly "org.springframework.experimental:spring-ai-mcp:${springAiMcpVersion}"

testImplementation "org.springframework.boot:spring-boot-starter-test"
testImplementation "org.springframework.ai:spring-ai-core:${springAiVersion}"
testImplementation "org.springframework.experimental:spring-ai-mcp:${springAiMcpVersion}"
}

publishing {
publications {
mavenJava(MavenPublication) {
pom {
name = "Arconia AI MCP"
description = "Arconia AI MCP."
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.arconia.ai.mcp;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.arconia.ai.mcp.tools;

import org.springframework.ai.mcp.client.McpSyncClient;
import org.springframework.ai.mcp.spec.McpSchema;
import org.springframework.ai.mcp.spring.McpFunctionCallback;

import io.arconia.ai.core.tools.ToolCallback;

/**
* A {@link ToolCallback} for handling calls to MCP tools.
*/
public class McpToolCallback extends McpFunctionCallback implements ToolCallback {

public McpToolCallback(McpSyncClient clientSession, McpSchema.Tool tool) {
super(clientSession, tool);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package io.arconia.ai.mcp.tools;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

import org.springframework.ai.mcp.client.McpAsyncClient;
import org.springframework.ai.mcp.client.McpSyncClient;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.util.Assert;

import io.arconia.ai.core.tools.ToolCallback;
import io.arconia.ai.core.tools.ToolCallbackResolver;

/**
* A {@link ToolCallbackResolver} that resolves {@link ToolCallback} instances from MCP
* tools.
*/
public class McpToolCallbackResolver implements ToolCallbackResolver {

private final List<McpSyncClient> mcpClients;

public McpToolCallbackResolver(List<McpSyncClient> mcpClients) {
Assert.notNull(mcpClients, "mcpClients cannot be null");
Assert.noNullElements(mcpClients, "mcpClients cannot contain null elements");
this.mcpClients = mcpClients;
}

@Override
public FunctionCallback[] getToolCallbacks() {
return mcpClients.stream()
.flatMap(mcpClient -> mcpClient.listTools()
.tools()
.stream()
.map(tool -> (ToolCallback) new McpToolCallback(mcpClient, tool)))
.toArray(ToolCallback[]::new);
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private List<McpSyncClient> mcpClients;

public Builder mcpClients(List<McpSyncClient> mcpClients) {
this.mcpClients = mcpClients;
return this;
}

public Builder mcpClients(McpSyncClient... mcpClients) {
Assert.notNull(mcpClients, "mcpClients cannot be null");
this.mcpClients = Arrays.asList(mcpClients);
return this;
}

public Builder mcpClients(McpAsyncClient... mcpClients) {
this.mcpClients = Stream.of(mcpClients).map(McpSyncClient::new).toList();
return this;
}

public McpToolCallbackResolver build() {
return new McpToolCallbackResolver(mcpClients);
}

}

}
4 changes: 4 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@ group=io.arconia
version=0.3.0-SNAPSHOT

org.gradle.parallel=true

# Dependency Versions
springAiVersion=1.0.0-M5
springAiMcpVersion=0.2.0
6 changes: 5 additions & 1 deletion settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ plugins {

rootProject.name = 'arconia'

// AI
include 'arconia-ai:arconia-ai-core'
include 'arconia-ai:arconia-ai-mcp'

// Multitenancy
include 'arconia-multitenancy:arconia-multitenancy-core'
include 'arconia-multitenancy:arconia-multitenancy-web'

include 'arconia-multitenancy:arconia-multitenancy-spring-boot-autoconfigure'
include 'arconia-multitenancy:arconia-multitenancy-spring-boot-starters:arconia-multitenancy-web-spring-boot-starter'

0 comments on commit b051bf3

Please sign in to comment.