Skip to content

Commit

Permalink
feat(ai): Support tools from non-public classes/methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Jan 5, 2025
1 parent a24f152 commit 0c308dc
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {

Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

Object result = ReflectionUtils.invokeMethod(toolMethod, toolObject, methodArguments);
Object result = callMethod(methodArguments);

Class<?> returnType = toolMethod.getReturnType();

Expand Down Expand Up @@ -114,6 +114,22 @@ private Object buildTypedArgument(@Nullable Object value, Class<?> type) {
return JsonParser.toTypedObject(value, type);
}

@Nullable
private Object callMethod(Object[] methodArguments) {
if (isObjectNotPublic() || isMethodNotPublic()) {
toolMethod.setAccessible(true);
}
return ReflectionUtils.invokeMethod(toolMethod, toolObject, methodArguments);
}

private boolean isObjectNotPublic() {
return !Modifier.isPublic(toolObject.getClass().getModifiers());
}

private boolean isMethodNotPublic() {
return !Modifier.isPublic(toolMethod.getModifiers());
}

// Based on the implementation in MethodInvokingFunctionCallback.
private String formatResult(@Nullable Object result, Class<?> returnType) {
if (returnType == Void.TYPE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class MethodToolCallbackProviderTests {

@Test
void shouldProvideToolCallbacksFromObject() {
TestComponent testComponent = new TestComponent();
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build();
Tools tools = new Tools();
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build();

ToolCallback[] callbacks = provider.getToolCallbacks();

Expand All @@ -39,48 +39,48 @@ void shouldProvideToolCallbacksFromObject() {

@Test
void shouldEnsureUniqueToolNames() {
TestComponentWithDuplicates testComponent = new TestComponentWithDuplicates();
ToolsWithDuplicates testComponent = new ToolsWithDuplicates();
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build();

assertThatThrownBy(provider::getToolCallbacks).isInstanceOf(IllegalStateException.class)
.hasMessageContaining("Multiple tools with the same name (testMethod) found in sources: "
+ testComponent.getClass().getName());
}

static class TestComponent {
static class Tools {

@Tool("Test description")
public static List<String> testStaticMethod(String input) {
static List<String> testStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
public List<String> testMethod(String input) {
List<String> testMethod(String input) {
return List.of(input);
}

@Tool("Test description")
public Function<String, Integer> testFunction(String input) {
Function<String, Integer> testFunction(String input) {
// This method should be ignored as it's a functional type, which is not
// supported.
return String::length;
}

public void nonToolMethod() {
void nonToolMethod() {
// This method should be ignored as it doesn't have @Tool annotation
}

}

static class TestComponentWithDuplicates {
static class ToolsWithDuplicates {

@Tool(name = "testMethod", value = "Test description")
public List<String> testMethod1(String input) {
List<String> testMethod1(String input) {
return List.of(input);
}

@Tool(name = "testMethod", value = "Test description")
public List<String> testMethod2(String input) {
List<String> testMethod2(String input) {
return List.of(input);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package io.arconia.ai.core.tools.method;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;

import com.fasterxml.jackson.core.type.TypeReference;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.util.ReflectionUtils;

import io.arconia.ai.core.tools.annotation.Tool;
import io.arconia.ai.core.tools.json.JsonParser;
import io.arconia.ai.core.tools.metadata.ToolMetadata;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Unit tests for {@link MethodToolCallback}.
*/
class MethodToolCallbackTests {

@ParameterizedTest
@ValueSource(strings = {
"publicStaticMethod",
"privateStaticMethod",
"packageStaticMethod",
"publicMethod",
"privateMethod",
"packageMethod"
})
void shouldCallToolFromPublicClass(String methodName) {
validateAssertions(methodName, new PublicTools());
}

@ParameterizedTest
@ValueSource(strings = {
"publicStaticMethod",
"privateStaticMethod",
"packageStaticMethod",
"publicMethod",
"privateMethod",
"packageMethod"
})
void shouldCallToolFromPrivateClass(String methodName) {
validateAssertions(methodName, new PrivateTools());
}

@ParameterizedTest
@ValueSource(strings = {
"publicStaticMethod",
"privateStaticMethod",
"packageStaticMethod",
"publicMethod",
"privateMethod",
"packageMethod"
})
void shouldCallToolFromPackageClass(String methodName) {
validateAssertions(methodName, new PackageTools());
}

private static void validateAssertions(String methodName, Object toolObject) {
Method toolMethod = getMethod(methodName, toolObject.getClass());
assertThat(toolMethod).isNotNull();
MethodToolCallback callback = MethodToolCallback.builder()
.toolMetadata(ToolMetadata.from(toolMethod))
.toolMethod(toolMethod)
.toolObject(toolObject)
.build();

String result = callback.call("""
{
"input": "Wingardium Leviosa"
}
""");

assertThat(JsonParser.fromJson(result, new TypeReference<List<String>>() {}))
.contains("Wingardium Leviosa");
}

private static Method getMethod(String name, Class<?> toolsClass) {
return Arrays.stream(ReflectionUtils.getDeclaredMethods(toolsClass))
.filter(m -> m.getName().equals(name))
.findFirst()
.orElseThrow();
}

static public class PublicTools {

@Tool("Test description")
public static List<String> publicStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
private static List<String> privateStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
static List<String> packageStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
public List<String> publicMethod(String input) {
return List.of(input);
}

@Tool("Test description")
private List<String> privateMethod(String input) {
return List.of(input);
}

@Tool("Test description")
List<String> packageMethod(String input) {
return List.of(input);
}

}

static private class PrivateTools {

@Tool("Test description")
public static List<String> publicStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
private static List<String> privateStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
static List<String> packageStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
public List<String> publicMethod(String input) {
return List.of(input);
}

@Tool("Test description")
private List<String> privateMethod(String input) {
return List.of(input);
}

@Tool("Test description")
List<String> packageMethod(String input) {
return List.of(input);
}

}

static class PackageTools {

@Tool("Test description")
public static List<String> publicStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
private static List<String> privateStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
static List<String> packageStaticMethod(String input) {
return List.of(input);
}

@Tool("Test description")
public List<String> publicMethod(String input) {
return List.of(input);
}

@Tool("Test description")
private List<String> privateMethod(String input) {
return List.of(input);
}

@Tool("Test description")
List<String> packageMethod(String input) {
return List.of(input);
}

}

}

0 comments on commit 0c308dc

Please sign in to comment.