From 0c308dc0c128db4d38bc963caa91c525e1718aa0 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Sun, 5 Jan 2025 12:07:17 +0100 Subject: [PATCH] feat(ai): Support tools from non-public classes/methods --- .../core/tools/method/MethodToolCallback.java | 18 +- .../MethodToolCallbackProviderTests.java | 22 +- .../tools/method/MethodToolCallbackTests.java | 191 ++++++++++++++++++ 3 files changed, 219 insertions(+), 12 deletions(-) create mode 100644 arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java diff --git a/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java b/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java index 6767032..cee3c2e 100644 --- a/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java +++ b/arconia-ai/arconia-ai-core/src/main/java/io/arconia/ai/core/tools/method/MethodToolCallback.java @@ -75,7 +75,7 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { Object[] methodArguments = buildMethodArguments(toolArguments, toolContext); - Object result = ReflectionUtils.invokeMethod(toolMethod, toolObject, methodArguments); + Object result = callMethod(methodArguments); Class returnType = toolMethod.getReturnType(); @@ -114,6 +114,22 @@ private Object buildTypedArgument(@Nullable Object value, Class type) { return JsonParser.toTypedObject(value, type); } + @Nullable + private Object callMethod(Object[] methodArguments) { + if (isObjectNotPublic() || isMethodNotPublic()) { + toolMethod.setAccessible(true); + } + return ReflectionUtils.invokeMethod(toolMethod, toolObject, methodArguments); + } + + private boolean isObjectNotPublic() { + return !Modifier.isPublic(toolObject.getClass().getModifiers()); + } + + private boolean isMethodNotPublic() { + return !Modifier.isPublic(toolMethod.getModifiers()); + } + // Based on the implementation in MethodInvokingFunctionCallback. private String formatResult(@Nullable Object result, Class returnType) { if (returnType == Void.TYPE) { diff --git a/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java b/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java index d5f53fc..cc4665e 100644 --- a/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java +++ b/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java @@ -19,8 +19,8 @@ class MethodToolCallbackProviderTests { @Test void shouldProvideToolCallbacksFromObject() { - TestComponent testComponent = new TestComponent(); - MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build(); + Tools tools = new Tools(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); ToolCallback[] callbacks = provider.getToolCallbacks(); @@ -39,7 +39,7 @@ void shouldProvideToolCallbacksFromObject() { @Test void shouldEnsureUniqueToolNames() { - TestComponentWithDuplicates testComponent = new TestComponentWithDuplicates(); + ToolsWithDuplicates testComponent = new ToolsWithDuplicates(); MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build(); assertThatThrownBy(provider::getToolCallbacks).isInstanceOf(IllegalStateException.class) @@ -47,40 +47,40 @@ void shouldEnsureUniqueToolNames() { + testComponent.getClass().getName()); } - static class TestComponent { + static class Tools { @Tool("Test description") - public static List testStaticMethod(String input) { + static List testStaticMethod(String input) { return List.of(input); } @Tool("Test description") - public List testMethod(String input) { + List testMethod(String input) { return List.of(input); } @Tool("Test description") - public Function testFunction(String input) { + Function testFunction(String input) { // This method should be ignored as it's a functional type, which is not // supported. return String::length; } - public void nonToolMethod() { + void nonToolMethod() { // This method should be ignored as it doesn't have @Tool annotation } } - static class TestComponentWithDuplicates { + static class ToolsWithDuplicates { @Tool(name = "testMethod", value = "Test description") - public List testMethod1(String input) { + List testMethod1(String input) { return List.of(input); } @Tool(name = "testMethod", value = "Test description") - public List testMethod2(String input) { + List testMethod2(String input) { return List.of(input); } diff --git a/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java b/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java new file mode 100644 index 0000000..309d992 --- /dev/null +++ b/arconia-ai/arconia-ai-core/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java @@ -0,0 +1,191 @@ +package io.arconia.ai.core.tools.method; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; + +import com.fasterxml.jackson.core.type.TypeReference; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.util.ReflectionUtils; + +import io.arconia.ai.core.tools.annotation.Tool; +import io.arconia.ai.core.tools.json.JsonParser; +import io.arconia.ai.core.tools.metadata.ToolMetadata; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link MethodToolCallback}. + */ +class MethodToolCallbackTests { + + @ParameterizedTest + @ValueSource(strings = { + "publicStaticMethod", + "privateStaticMethod", + "packageStaticMethod", + "publicMethod", + "privateMethod", + "packageMethod" + }) + void shouldCallToolFromPublicClass(String methodName) { + validateAssertions(methodName, new PublicTools()); + } + + @ParameterizedTest + @ValueSource(strings = { + "publicStaticMethod", + "privateStaticMethod", + "packageStaticMethod", + "publicMethod", + "privateMethod", + "packageMethod" + }) + void shouldCallToolFromPrivateClass(String methodName) { + validateAssertions(methodName, new PrivateTools()); + } + + @ParameterizedTest + @ValueSource(strings = { + "publicStaticMethod", + "privateStaticMethod", + "packageStaticMethod", + "publicMethod", + "privateMethod", + "packageMethod" + }) + void shouldCallToolFromPackageClass(String methodName) { + validateAssertions(methodName, new PackageTools()); + } + + private static void validateAssertions(String methodName, Object toolObject) { + Method toolMethod = getMethod(methodName, toolObject.getClass()); + assertThat(toolMethod).isNotNull(); + MethodToolCallback callback = MethodToolCallback.builder() + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(toolObject) + .build(); + + String result = callback.call(""" + { + "input": "Wingardium Leviosa" + } + """); + + assertThat(JsonParser.fromJson(result, new TypeReference>() {})) + .contains("Wingardium Leviosa"); + } + + private static Method getMethod(String name, Class toolsClass) { + return Arrays.stream(ReflectionUtils.getDeclaredMethods(toolsClass)) + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(); + } + + static public class PublicTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static private class PrivateTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static class PackageTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + +}