Skip to content

Commit

Permalink
feat(ai): More robust and extensive tool calling support based on met…
Browse files Browse the repository at this point in the history
…hods
  • Loading branch information
ThomasVitale committed Jan 11, 2025
1 parent 6393ae6 commit 643fb10
Show file tree
Hide file tree
Showing 32 changed files with 2,070 additions and 485 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import io.arconia.ai.tools.metadata.ToolMetadata;

/**
* Specialization of {@link FunctionCallback} to identify tools in Spring AI.
* Represents a tool whose execution can be triggered by an AI model.
*/
public interface ToolCallback extends FunctionCallback {

/**
* Definition of the tool.
* Definition used by the AI model to determine when and how to call the tool.
*/
ToolDefinition getToolDefinition();

/**
* Metadata for the tool.
* Metadata providing additional information on how to handle the tool.
*/
default ToolMetadata getToolMetadata() {
return ToolMetadata.builder().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import io.arconia.ai.tools.execution.DefaultToolCallResultConverter;
import io.arconia.ai.tools.execution.ToolCallResultConverter;
import io.arconia.ai.tools.execution.ToolExecutionMode;

/**
Expand Down Expand Up @@ -36,4 +38,9 @@
*/
boolean returnDirect() default false;

/**
* The class to use to convert the tool call result to a String.
*/
Class<? extends ToolCallResultConverter> resultConverter() default DefaultToolCallResultConverter.class;

}
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
package io.arconia.ai.tools.definition;

import java.lang.reflect.Method;

import org.springframework.util.Assert;

import io.arconia.ai.tools.json.JsonSchemaGenerator;
import io.arconia.ai.tools.util.ToolUtils;

/**
* Default implementation of {@link ToolDefinition}.
*/
public record DefaultToolDefinition(String name, String description, String inputTypeSchema) implements ToolDefinition {

public DefaultToolDefinition {
Assert.hasText(name, "name cannot be null");
Assert.hasText(description, "description cannot be null");
Assert.hasText(inputTypeSchema, "inputTypeSchema cannot be null");
}

static DefaultToolDefinition from(Method method) {
return DefaultToolDefinition.builder()
.name(ToolUtils.getToolName(method))
.description(ToolUtils.getToolDescription(method))
.inputTypeSchema(JsonSchemaGenerator.generate(method))
.build();
Assert.hasText(name, "name cannot be null or empty");
Assert.hasText(description, "description cannot be null or empty");
Assert.hasText(inputTypeSchema, "inputTypeSchema cannot be null or empty");
}

public static Builder builder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import java.lang.reflect.Method;

import io.arconia.ai.tools.json.JsonSchemaGenerator;
import io.arconia.ai.tools.util.ToolUtils;

/**
* Definition of a tool that can be used by a model.
* Definition used by the AI model to determine when and how to call the tool.
*/
public interface ToolDefinition {

Expand All @@ -13,7 +16,7 @@ public interface ToolDefinition {
String name();

/**
* The tool description, used by the model to decide if and when to use the tool.
* The tool description, used by the AI model to determine what the tool does.
*/
String description();

Expand All @@ -30,10 +33,14 @@ static DefaultToolDefinition.Builder builder() {
}

/**
* Create {@link ToolDefinition} from a {@link Method}.
* Create a default {@link ToolDefinition} instance from a {@link Method}.
*/
static ToolDefinition from(Method method) {
return DefaultToolDefinition.from(method);
return DefaultToolDefinition.builder()
.name(ToolUtils.getToolName(method))
.description(ToolUtils.getToolDescription(method))
.inputTypeSchema(JsonSchemaGenerator.generateForMethodInput(method))
.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.arconia.ai.tools.execution;

import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

import io.arconia.ai.tools.json.JsonParser;

/**
* A default implementation of {@link ToolCallResultConverter}.
*/
public class DefaultToolCallResultConverter implements ToolCallResultConverter {

@Override
public String apply(@Nullable Object result, Class<?> returnType) {
Assert.notNull(returnType, "returnType cannot be null");
if (returnType == Void.TYPE) {
return "Done";
} else {
return JsonParser.toJson(result);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.arconia.ai.tools.execution;

import java.util.function.BiFunction;

import org.springframework.lang.Nullable;

/**
* A functional interface to convert tool call results to a String
* that can be sent back to the AI model.
*/
@FunctionalInterface
public interface ToolCallResultConverter extends BiFunction<Object, Class<?>, String> {

/**
* Given an Object returned by a tool, convert it
* to a String compatible with the given class type.
*/
String apply(@Nullable Object result, Class<?> returnType);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.arconia.ai.tools.execution;

import io.arconia.ai.tools.definition.ToolDefinition;

/**
* An exception thrown when a tool execution fails.
*/
public class ToolExecutionException extends RuntimeException {

private final ToolDefinition toolDefinition;

public ToolExecutionException(ToolDefinition toolDefinition, Throwable cause) {
super(cause.getMessage(), cause);
this.toolDefinition = toolDefinition;
}

public ToolDefinition getToolDefinition() {
return toolDefinition;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,10 @@
* How the tool should be executed.
*/
public enum ToolExecutionMode {
BLOCKING

/**
* The tool should be executed in a blocking manner.
*/
BLOCKING;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.arconia.ai.tools.execution;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,31 @@ public class JsonParser {
.addModules(JacksonUtils.instantiateAvailableModules())
.build();

/**
* Returns a Jackson {@link ObjectMapper} instance tailored for
* JSON-parsing operations for tool calling and structured output.
*/
public static ObjectMapper getObjectMapper() {
return OBJECT_MAPPER;
}

/**
* Converts a JSON string to a Java object.
*/
public static <T> T fromJson(String json, Class<T> type) {
Assert.notNull(json, "json cannot be null");
Assert.notNull(type, "type cannot be null");

try {
return OBJECT_MAPPER.readValue(json, type);
} catch (JsonProcessingException ex) {
throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getName()), ex);
}
}

/**
* Converts a JSON string to a Java object.
*/
public static <T> T fromJson(String json, TypeReference<T> type) {
Assert.notNull(json, "json cannot be null");
Assert.notNull(type, "type cannot be null");
Expand All @@ -38,6 +59,9 @@ public static <T> T fromJson(String json, TypeReference<T> type) {
}
}

/**
* Converts a Java object to a JSON string.
*/
public static String toJson(@Nullable Object object) {
try {
return OBJECT_MAPPER.writeValueAsString(object);
Expand All @@ -46,7 +70,10 @@ public static String toJson(@Nullable Object object) {
}
}

// Based on the implementation in MethodInvokingFunctionCallback.
/**
* Convert a Java Object to a typed Object.
* Based on the implementation in MethodInvokingFunctionCallback.
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
public static Object toTypedObject(Object value, Class<?> type) {
Assert.notNull(value, "value cannot be null");
Expand Down Expand Up @@ -74,12 +101,8 @@ public static Object toTypedObject(Object value, Class<?> type) {
return Enum.valueOf((Class<Enum>) javaType, value.toString());
}

try {
String json = OBJECT_MAPPER.writeValueAsString(value);
return OBJECT_MAPPER.readValue(json, javaType);
} catch (JsonProcessingException ex) {
throw new IllegalStateException("Conversion from Object to %s failed".formatted(type.getName()), ex);
}
String json = JsonParser.toJson(value);
return JsonParser.fromJson(json, javaType);
}

}
Loading

0 comments on commit 643fb10

Please sign in to comment.