diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/ToolCallback.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/ToolCallback.java index e50479a..e4aa085 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/ToolCallback.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/ToolCallback.java @@ -6,17 +6,17 @@ import io.arconia.ai.tools.metadata.ToolMetadata; /** - * Specialization of {@link FunctionCallback} to identify tools in Spring AI. + * Represents a tool whose execution can be triggered by an AI model. */ public interface ToolCallback extends FunctionCallback { /** - * Definition of the tool. + * Definition used by the AI model to determine when and how to call the tool. */ ToolDefinition getToolDefinition(); /** - * Metadata for the tool. + * Metadata providing additional information on how to handle the tool. */ default ToolMetadata getToolMetadata() { return ToolMetadata.builder().build(); diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/annotation/Tool.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/annotation/Tool.java index 1cc726e..5af3a3b 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/annotation/Tool.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/annotation/Tool.java @@ -6,6 +6,8 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import io.arconia.ai.tools.execution.DefaultToolCallResultConverter; +import io.arconia.ai.tools.execution.ToolCallResultConverter; import io.arconia.ai.tools.execution.ToolExecutionMode; /** @@ -36,4 +38,9 @@ */ boolean returnDirect() default false; + /** + * The class to use to convert the tool call result to a String. + */ + Class resultConverter() default DefaultToolCallResultConverter.class; + } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/DefaultToolDefinition.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/DefaultToolDefinition.java index 416fa7e..576ff4a 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/DefaultToolDefinition.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/DefaultToolDefinition.java @@ -1,29 +1,16 @@ package io.arconia.ai.tools.definition; -import java.lang.reflect.Method; - import org.springframework.util.Assert; -import io.arconia.ai.tools.json.JsonSchemaGenerator; -import io.arconia.ai.tools.util.ToolUtils; - /** * Default implementation of {@link ToolDefinition}. */ public record DefaultToolDefinition(String name, String description, String inputTypeSchema) implements ToolDefinition { public DefaultToolDefinition { - Assert.hasText(name, "name cannot be null"); - Assert.hasText(description, "description cannot be null"); - Assert.hasText(inputTypeSchema, "inputTypeSchema cannot be null"); - } - - static DefaultToolDefinition from(Method method) { - return DefaultToolDefinition.builder() - .name(ToolUtils.getToolName(method)) - .description(ToolUtils.getToolDescription(method)) - .inputTypeSchema(JsonSchemaGenerator.generate(method)) - .build(); + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.hasText(inputTypeSchema, "inputTypeSchema cannot be null or empty"); } public static Builder builder() { diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/ToolDefinition.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/ToolDefinition.java index 8faf302..f80e94d 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/ToolDefinition.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/definition/ToolDefinition.java @@ -2,8 +2,11 @@ import java.lang.reflect.Method; +import io.arconia.ai.tools.json.JsonSchemaGenerator; +import io.arconia.ai.tools.util.ToolUtils; + /** - * Definition of a tool that can be used by a model. + * Definition used by the AI model to determine when and how to call the tool. */ public interface ToolDefinition { @@ -13,7 +16,7 @@ public interface ToolDefinition { String name(); /** - * The tool description, used by the model to decide if and when to use the tool. + * The tool description, used by the AI model to determine what the tool does. */ String description(); @@ -30,10 +33,14 @@ static DefaultToolDefinition.Builder builder() { } /** - * Create {@link ToolDefinition} from a {@link Method}. + * Create a default {@link ToolDefinition} instance from a {@link Method}. */ static ToolDefinition from(Method method) { - return DefaultToolDefinition.from(method); + return DefaultToolDefinition.builder() + .name(ToolUtils.getToolName(method)) + .description(ToolUtils.getToolDescription(method)) + .inputTypeSchema(JsonSchemaGenerator.generateForMethodInput(method)) + .build(); } } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/DefaultToolCallResultConverter.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/DefaultToolCallResultConverter.java new file mode 100644 index 0000000..982a8f9 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/DefaultToolCallResultConverter.java @@ -0,0 +1,23 @@ +package io.arconia.ai.tools.execution; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import io.arconia.ai.tools.json.JsonParser; + +/** + * A default implementation of {@link ToolCallResultConverter}. + */ +public class DefaultToolCallResultConverter implements ToolCallResultConverter { + + @Override + public String apply(@Nullable Object result, Class returnType) { + Assert.notNull(returnType, "returnType cannot be null"); + if (returnType == Void.TYPE) { + return "Done"; + } else { + return JsonParser.toJson(result); + } + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolCallResultConverter.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolCallResultConverter.java new file mode 100644 index 0000000..d064b12 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolCallResultConverter.java @@ -0,0 +1,20 @@ +package io.arconia.ai.tools.execution; + +import java.util.function.BiFunction; + +import org.springframework.lang.Nullable; + +/** + * A functional interface to convert tool call results to a String + * that can be sent back to the AI model. + */ +@FunctionalInterface +public interface ToolCallResultConverter extends BiFunction, String> { + + /** + * Given an Object returned by a tool, convert it + * to a String compatible with the given class type. + */ + String apply(@Nullable Object result, Class returnType); + +} diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionException.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionException.java new file mode 100644 index 0000000..f1bde52 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionException.java @@ -0,0 +1,21 @@ +package io.arconia.ai.tools.execution; + +import io.arconia.ai.tools.definition.ToolDefinition; + +/** + * An exception thrown when a tool execution fails. + */ +public class ToolExecutionException extends RuntimeException { + + private final ToolDefinition toolDefinition; + + public ToolExecutionException(ToolDefinition toolDefinition, Throwable cause) { + super(cause.getMessage(), cause); + this.toolDefinition = toolDefinition; + } + + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionMode.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionMode.java index 5e50fc6..2f8a22a 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionMode.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/ToolExecutionMode.java @@ -4,5 +4,10 @@ * How the tool should be executed. */ public enum ToolExecutionMode { - BLOCKING + + /** + * The tool should be executed in a blocking manner. + */ + BLOCKING; + } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/package-info.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/package-info.java new file mode 100644 index 0000000..4c40704 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/execution/package-info.java @@ -0,0 +1,6 @@ +@NonNullApi +@NonNullFields +package io.arconia.ai.tools.execution; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonParser.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonParser.java index 8e7d793..1f2dbe2 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonParser.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonParser.java @@ -23,10 +23,31 @@ public class JsonParser { .addModules(JacksonUtils.instantiateAvailableModules()) .build(); + /** + * Returns a Jackson {@link ObjectMapper} instance tailored for + * JSON-parsing operations for tool calling and structured output. + */ public static ObjectMapper getObjectMapper() { return OBJECT_MAPPER; } + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, Class type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, type); + } catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getName()), ex); + } + } + + /** + * Converts a JSON string to a Java object. + */ public static T fromJson(String json, TypeReference type) { Assert.notNull(json, "json cannot be null"); Assert.notNull(type, "type cannot be null"); @@ -38,6 +59,9 @@ public static T fromJson(String json, TypeReference type) { } } + /** + * Converts a Java object to a JSON string. + */ public static String toJson(@Nullable Object object) { try { return OBJECT_MAPPER.writeValueAsString(object); @@ -46,7 +70,10 @@ public static String toJson(@Nullable Object object) { } } - // Based on the implementation in MethodInvokingFunctionCallback. + /** + * Convert a Java Object to a typed Object. + * Based on the implementation in MethodInvokingFunctionCallback. + */ @SuppressWarnings({ "rawtypes", "unchecked" }) public static Object toTypedObject(Object value, Class type) { Assert.notNull(value, "value cannot be null"); @@ -74,12 +101,8 @@ public static Object toTypedObject(Object value, Class type) { return Enum.valueOf((Class) javaType, value.toString()); } - try { - String json = OBJECT_MAPPER.writeValueAsString(value); - return OBJECT_MAPPER.readValue(json, javaType); - } catch (JsonProcessingException ex) { - throw new IllegalStateException("Conversion from Object to %s failed".formatted(type.getName()), ex); - } + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, javaType); } } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonSchemaGenerator.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonSchemaGenerator.java index 20ec6fb..157af16 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonSchemaGenerator.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/json/JsonSchemaGenerator.java @@ -2,74 +2,161 @@ import java.lang.reflect.Method; import java.lang.reflect.Parameter; -import java.util.concurrent.atomic.AtomicReference; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; import java.util.stream.Stream; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.github.victools.jsonschema.generator.Option; import com.github.victools.jsonschema.generator.OptionPreset; import com.github.victools.jsonschema.generator.SchemaGenerator; -import com.github.victools.jsonschema.generator.SchemaGeneratorConfig; import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder; import com.github.victools.jsonschema.generator.SchemaVersion; import com.github.victools.jsonschema.module.jackson.JacksonModule; import com.github.victools.jsonschema.module.jackson.JacksonOption; import com.github.victools.jsonschema.module.swagger2.Swagger2Module; +import org.springframework.util.Assert; + /** * Utilities to generate JSON Schemas from Java entities. */ public class JsonSchemaGenerator { - private static final AtomicReference SCHEMA_GENERATOR = new AtomicReference<>(); - - public static String generate(Method method) { - var generator = buildSchemaGenerator(); + private static final SchemaGenerator TYPE_SCHEMA_GENERATOR; + private static final SchemaGenerator SUBTYPE_SCHEMA_GENERATOR; + + /* + * Initialize JSON Schema generators. + */ + static { + var schemaGeneratorConfigBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON) + .with(new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED)) + .with(new Swagger2Module()) + .with(Option.EXTRA_OPEN_API_FORMAT_VALUES) + .with(Option.PLAIN_DEFINITION_KEYS); + + var typeSchemaGeneratorConfig = schemaGeneratorConfigBuilder + .without(Option.SCHEMA_VERSION_INDICATOR) + .build(); + TYPE_SCHEMA_GENERATOR = new SchemaGenerator(typeSchemaGeneratorConfig); + + var subtypeSchemaGeneratorConfig = schemaGeneratorConfigBuilder.build(); + SUBTYPE_SCHEMA_GENERATOR = new SchemaGenerator(subtypeSchemaGeneratorConfig); + } + /** + * Generate a JSON Schema for a method's input parameters. + */ + public static String generateForMethodInput(Method method, SchemaOption... schemaOptions) { ObjectNode schema = JsonParser.getObjectMapper().createObjectNode(); - schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier()); // Option.SCHEMA_VERSION_INDICATOR + schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier()); schema.put("type", "object"); ObjectNode properties = schema.putObject("properties"); + List required = new ArrayList<>(); for (int i = 0; i < method.getParameterCount(); i++) { var parameterName = method.getParameters()[i].getName(); var parameterType = method.getGenericParameterTypes()[i]; - properties.set(parameterName, generator.generateSchema(parameterType)); + if (isMethodParameterRequired(method, i)) { + required.add(parameterName); + } + properties.set(parameterName, SUBTYPE_SCHEMA_GENERATOR.generateSchema(parameterType)); } - schema.put("additionalProperties", false); // Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT - var requiredArray = schema.putArray("required"); - Stream.of(method.getParameters()).map(Parameter::getName).forEach(requiredArray::add); + if (Stream.of(schemaOptions).anyMatch(option -> option == SchemaOption.RESPECT_JSON_PROPERTY_REQUIRED)) { + required.forEach(requiredArray::add); + } else { + Stream.of(method.getParameters()).map(Parameter::getName).forEach(requiredArray::add); + } + + if (Stream.of(schemaOptions).noneMatch(option -> option == SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) { + schema.put("additionalProperties", false); + } + + if (Stream.of(schemaOptions).anyMatch(option -> option == SchemaOption.UPPER_CASE_TYPE_VALUES)) { + convertTypeValuesToUpperCase(schema); + } return schema.toPrettyString(); } - // Based on the implementation in ModelOptionsUtils. - private static SchemaGenerator buildSchemaGenerator() { - if (SCHEMA_GENERATOR.get() != null) { - return SCHEMA_GENERATOR.get(); + /** + * Generate a JSON Schema for a class type. + */ + public static String generateForType(Type type, SchemaOption... schemaOptions) { + Assert.notNull(type, "type cannot be null"); + ObjectNode schema = TYPE_SCHEMA_GENERATOR.generateSchema(type); + if (Stream.of(schemaOptions).noneMatch(option -> option == SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) { + schema.put("additionalProperties", false); } + if (Stream.of(schemaOptions).anyMatch(option -> option == SchemaOption.UPPER_CASE_TYPE_VALUES)) { + convertTypeValuesToUpperCase(schema); + } + return schema.toPrettyString(); + } + + private static boolean isMethodParameterRequired(Method method, int index) { + var jsonPropertyAnnotation = method.getParameters()[index].getAnnotation(JsonProperty.class); + if (jsonPropertyAnnotation == null) { + return false; + } + return jsonPropertyAnnotation.required(); + } + + // Based on the method in ModelOptionsUtils. + private static void convertTypeValuesToUpperCase(ObjectNode node) { + if (node.isObject()) { + node.fields().forEachRemaining(entry -> { + JsonNode value = entry.getValue(); + if (value.isObject()) { + convertTypeValuesToUpperCase((ObjectNode) value); + } else if (value.isArray()) { + value.elements().forEachRemaining(element -> { + if (element.isObject() || element.isArray()) { + convertTypeValuesToUpperCase((ObjectNode) element); + } + }); + } else if (value.isTextual() && entry.getKey().equals("type")) { + String oldValue = node.get("type").asText(); + node.put("type", oldValue.toUpperCase()); + } + }); + } else if (node.isArray()) { + node.elements().forEachRemaining(element -> { + if (element.isObject() || element.isArray()) { + convertTypeValuesToUpperCase((ObjectNode) element); + } + }); + } + } - JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED); - Swagger2Module swaggerModule = new Swagger2Module(); + /** + * Options for generating JSON Schemas. + */ + public enum SchemaOption { - SchemaGeneratorConfig schemaGeneratorConfig = new SchemaGeneratorConfigBuilder(JsonParser.getObjectMapper(), - SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON) - .with(jacksonModule) - .with(swaggerModule) - // .with(Option.DEFINITIONS_FOR_ALL_OBJECTS) - .with(Option.EXTRA_OPEN_API_FORMAT_VALUES) - .with(Option.PLAIN_DEFINITION_KEYS) - .without(Option.SCHEMA_VERSION_INDICATOR) - .build(); + /** + * Properties are only required if marked as such via the Jackson annotation "@JsonProperty(required = true)". + * Beware, that OpenAI requires all properties to be required. + */ + RESPECT_JSON_PROPERTY_REQUIRED, - SchemaGenerator generator = new SchemaGenerator(schemaGeneratorConfig); + /** + * Allow additional properties by default. Beware, that OpenAI requires additional properties NOT to be allowed. + */ + ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT, - SCHEMA_GENERATOR.set(generator); + /** + * Convert all "type" values to upper case. For example, it's require in OpenAPI 3.0 with Vertex AI. + */ + UPPER_CASE_TYPE_VALUES; - return generator; } } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/DefaultToolMetadata.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/DefaultToolMetadata.java index e8a8357..f7cc6b0 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/DefaultToolMetadata.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/DefaultToolMetadata.java @@ -1,22 +1,12 @@ package io.arconia.ai.tools.metadata; -import java.lang.reflect.Method; - import io.arconia.ai.tools.execution.ToolExecutionMode; -import io.arconia.ai.tools.util.ToolUtils; /** * Default implementation of {@link ToolMetadata}. */ public record DefaultToolMetadata(ToolExecutionMode executionMode, boolean returnDirect) implements ToolMetadata { - static DefaultToolMetadata from(Method method) { - return DefaultToolMetadata.builder() - .executionMode(ToolUtils.getToolExecutionMode(method)) - .returnDirect(ToolUtils.getToolReturnDirect(method)) - .build(); - } - public static Builder builder() { return new Builder(); } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/ToolMetadata.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/ToolMetadata.java index 9fea0fd..750534a 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/ToolMetadata.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/metadata/ToolMetadata.java @@ -3,6 +3,7 @@ import java.lang.reflect.Method; import io.arconia.ai.tools.execution.ToolExecutionMode; +import io.arconia.ai.tools.util.ToolUtils; /** * Metadata about a tool specification and execution. @@ -31,10 +32,13 @@ static DefaultToolMetadata.Builder builder() { } /** - * Create {@link ToolMetadata} from a {@link Method}. + * Create a default {@link ToolMetadata} instance from a {@link Method}. */ static ToolMetadata from(Method method) { - return DefaultToolMetadata.from(method); + return DefaultToolMetadata.builder() + .executionMode(ToolUtils.getToolExecutionMode(method)) + .returnDirect(ToolUtils.getToolReturnDirect(method)) + .build(); } } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallback.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallback.java index ce0b8bc..6091a8e 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallback.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallback.java @@ -1,5 +1,6 @@ package io.arconia.ai.tools.method; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Map; @@ -7,15 +8,19 @@ import com.fasterxml.jackson.core.type.TypeReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.model.ToolContext; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; -import org.springframework.util.ReflectionUtils; import io.arconia.ai.tools.ToolCallback; import io.arconia.ai.tools.definition.ToolDefinition; +import io.arconia.ai.tools.execution.DefaultToolCallResultConverter; +import io.arconia.ai.tools.execution.ToolCallResultConverter; +import io.arconia.ai.tools.execution.ToolExecutionException; import io.arconia.ai.tools.json.JsonParser; import io.arconia.ai.tools.metadata.ToolMetadata; @@ -24,26 +29,30 @@ */ public class MethodToolCallback implements ToolCallback { + private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class); + + private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter(); + private final ToolDefinition toolDefinition; private final ToolMetadata toolMetadata; private final Method toolMethod; - @Nullable private final Object toolObject; - public MethodToolCallback(ToolDefinition toolDefinition, ToolMetadata toolMetadata, Method toolMethod, @Nullable Object toolObject) { + private final ToolCallResultConverter toolCallResultConverter; + + public MethodToolCallback(ToolDefinition toolDefinition, ToolMetadata toolMetadata, Method toolMethod, Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) { Assert.notNull(toolDefinition, "toolDefinition cannot be null"); Assert.notNull(toolMetadata, "toolMetadata cannot be null"); Assert.notNull(toolMethod, "toolMethod cannot be null"); - if (!Modifier.isStatic(toolMethod.getModifiers())) { - Assert.notNull(toolObject, "toolObject cannot be null for non-static method"); - } + Assert.notNull(toolObject, "toolObject cannot be null"); this.toolDefinition = toolDefinition; this.toolMetadata = toolMetadata; this.toolMethod = toolMethod; this.toolObject = toolObject; + this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter : DEFAULT_RESULT_CONVERTER; } @Override @@ -65,23 +74,37 @@ public String call(String toolInput) { public String call(String toolInput, @Nullable ToolContext toolContext) { Assert.hasText(toolInput, "toolInput cannot be null or empty"); + logger.debug("Starting execution of tool: {}", toolDefinition.name()); + validateToolContextSupport(toolContext); Map toolArguments = extractToolArguments(toolInput); Object[] methodArguments = buildMethodArguments(toolArguments, toolContext); - Object result = callMethod(methodArguments); + Object result; + try { + result = callMethod(methodArguments); + logger.debug("Successful execution of tool: {}", toolDefinition.name()); + } catch (ToolExecutionException ex) { + if (toolMetadata.returnDirect()) { + // When the tool result should be returned directly to the user instead of back to the model, + // we should rethrow the exception to be handled by the caller. + throw ex; + } + logger.error("Failed execution of tool: {}", toolDefinition.name(), ex); + return ex.getMessage(); + } Class returnType = toolMethod.getReturnType(); - return formatResult(result, returnType); + return toolCallResultConverter.apply(result, returnType); } private void validateToolContextSupport(@Nullable ToolContext toolContext) { var isToolContextRequired = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()); - var isToolContextAcceptedByMethod = Stream.of(toolMethod.getGenericParameterTypes()) - .anyMatch(type -> ClassUtils.isAssignable(type.getClass(), ToolContext.class)); + var isToolContextAcceptedByMethod = Stream.of(toolMethod.getParameterTypes()) + .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class)); if (isToolContextRequired && !isToolContextAcceptedByMethod) { throw new IllegalArgumentException("ToolContext is not supported by the method as an argument"); } @@ -115,7 +138,16 @@ private Object callMethod(Object[] methodArguments) { if (isObjectNotPublic() || isMethodNotPublic()) { toolMethod.setAccessible(true); } - return ReflectionUtils.invokeMethod(toolMethod, toolObject, methodArguments); + + Object result; + try { + result = toolMethod.invoke(toolObject, methodArguments); + } catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); + } catch (InvocationTargetException ex) { + throw new ToolExecutionException(toolDefinition, ex.getCause()); + } + return result; } private boolean isObjectNotPublic() { @@ -126,17 +158,6 @@ private boolean isMethodNotPublic() { return !Modifier.isPublic(toolMethod.getModifiers()); } - // Based on the implementation in MethodInvokingFunctionCallback. - private String formatResult(@Nullable Object result, Class returnType) { - if (returnType == Void.TYPE) { - return "Done"; - } else if (returnType == String.class) { - return result != null ? (String) result : ""; - } else { - return JsonParser.toJson(result); - } - } - public static Builder builder() { return new Builder(); } @@ -151,6 +172,8 @@ public static class Builder { private Object toolObject; + private ToolCallResultConverter toolCallResultConverter; + private Builder() {} public Builder toolDefinition(ToolDefinition toolDefinition) { @@ -173,8 +196,13 @@ public Builder toolObject(Object toolObject) { return this; } + public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) { + this.toolCallResultConverter = toolCallResultConverter; + return this; + } + public MethodToolCallback build() { - return new MethodToolCallback(toolDefinition, toolMetadata, toolMethod, toolObject); + return new MethodToolCallback(toolDefinition, toolMetadata, toolMethod, toolObject, toolCallResultConverter); } } diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallbackProvider.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallbackProvider.java index d50544f..34ad7fb 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallbackProvider.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/method/MethodToolCallbackProvider.java @@ -49,6 +49,7 @@ public ToolCallback[] getToolCallbacks() { .toolMetadata(ToolMetadata.from(toolMethod)) .toolMethod(toolMethod) .toolObject(toolObject) + .toolCallResultConverter(ToolUtils.getToolCallResultConverter(toolMethod)) .build()) .toArray(ToolCallback[]::new)) .flatMap(Stream::of) @@ -59,7 +60,7 @@ public ToolCallback[] getToolCallbacks() { return toolCallbacks; } - private static boolean isFunctionalType(Method toolMethod) { + private boolean isFunctionalType(Method toolMethod) { var isFunction = ClassUtils.isAssignable(toolMethod.getReturnType(), Function.class) || ClassUtils.isAssignable(toolMethod.getReturnType(), Supplier.class) || ClassUtils.isAssignable(toolMethod.getReturnType(), Consumer.class); diff --git a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/util/ToolUtils.java b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/util/ToolUtils.java index 9d36114..390644b 100644 --- a/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/util/ToolUtils.java +++ b/arconia-ai/arconia-ai-tools/src/main/java/io/arconia/ai/tools/util/ToolUtils.java @@ -11,6 +11,8 @@ import org.springframework.util.StringUtils; import io.arconia.ai.tools.annotation.Tool; +import io.arconia.ai.tools.execution.DefaultToolCallResultConverter; +import io.arconia.ai.tools.execution.ToolCallResultConverter; import io.arconia.ai.tools.execution.ToolExecutionMode; /** @@ -44,6 +46,19 @@ public static boolean getToolReturnDirect(Method method) { return tool != null && tool.returnDirect(); } + public static ToolCallResultConverter getToolCallResultConverter(Method method) { + var tool = method.getAnnotation(Tool.class); + if (tool == null) { + return new DefaultToolCallResultConverter(); + } + var type = tool.resultConverter(); + try { + return type.getDeclaredConstructor().newInstance(); + } catch (Exception e) { + throw new IllegalArgumentException("Failed to instantiate ToolCallResultConverter: " + type, e); + } + } + public static List getDuplicateToolNames(FunctionCallback... functionCallbacks) { return Stream.of(functionCallbacks) .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting())) diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java deleted file mode 100644 index 95da5e6..0000000 --- a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/ToolUtilsTests.java +++ /dev/null @@ -1,65 +0,0 @@ -package io.arconia.ai.core.tools; - -import java.util.List; - -import io.arconia.ai.tools.ToolCallback; -import org.junit.jupiter.api.Test; - -import io.arconia.ai.tools.definition.ToolDefinition; -import io.arconia.ai.tools.util.ToolUtils; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Unit tests for {@link ToolUtils}. - */ -class ToolUtilsTests { - - @Test - void shouldDetectDuplicateToolNames() { - ToolCallback callback1 = new TestToolCallback("tool_a"); - ToolCallback callback2 = new TestToolCallback("tool_a"); - ToolCallback callback3 = new TestToolCallback("tool_b"); - - List duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3); - - assertThat(duplicates).isNotEmpty(); - assertThat(duplicates).contains("tool_a"); - } - - @Test - void shouldNotDetectDuplicateToolNames() { - ToolCallback callback1 = new TestToolCallback("tool_a"); - ToolCallback callback2 = new TestToolCallback("tool_b"); - ToolCallback callback3 = new TestToolCallback("tool_c"); - - List duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3); - - assertThat(duplicates).isEmpty(); - } - - static class TestToolCallback implements ToolCallback { - - private final ToolDefinition toolDefinition; - - public TestToolCallback(String name) { - this.toolDefinition = ToolDefinition.builder() - .name(name) - .description(name) - .inputTypeSchema("{}") - .build(); - } - - @Override - public ToolDefinition getToolDefinition() { - return toolDefinition; - } - - @Override - public String call(String functionInput) { - return ""; - } - - } - -} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/json/JsonSchemaGeneratorTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/json/JsonSchemaGeneratorTests.java deleted file mode 100644 index 556a746..0000000 --- a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/json/JsonSchemaGeneratorTests.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.arconia.ai.tools.json; - -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.List; - -import org.springframework.util.ReflectionUtils; - -class JsonSchemaGeneratorTests { - - private static Method getMethod(String name) { - return Arrays.stream(ReflectionUtils.getDeclaredMethods(TestClass.class)) - .filter(m -> m.getName().equals(name)) - .findFirst() - .orElseThrow(); - } - - public static class TestClass { - - public static String staticMethodName(String arg1, Integer arg2) { - return arg1 + arg2; - } - - public String methodName(String arg1, Integer arg2) { - return arg1 + arg2; - } - - public String noArgsMethod() { - return "Hello"; - } - - public String oneArgMethod(String greeting) { - return greeting; - } - - public String oneArgMethodList(List greetings) { - return String.join(", ", greetings); - } - - } - -} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java deleted file mode 100644 index dd58523..0000000 --- a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackProviderTests.java +++ /dev/null @@ -1,89 +0,0 @@ -package io.arconia.ai.tools.method; - -import java.util.List; -import java.util.function.Function; -import java.util.stream.Stream; - -import org.junit.jupiter.api.Test; - -import io.arconia.ai.tools.ToolCallback; -import io.arconia.ai.tools.annotation.Tool; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for {@link MethodToolCallbackProvider}. - */ -class MethodToolCallbackProviderTests { - - @Test - void shouldProvideToolCallbacksFromObject() { - Tools tools = new Tools(); - MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); - - ToolCallback[] callbacks = provider.getToolCallbacks(); - - assertThat(callbacks).hasSize(2); - - var callback1 = Stream.of(callbacks).filter(c -> c.getName().equals("testMethod")).findFirst(); - assertThat(callback1).isPresent(); - assertThat(callback1.get().getName()).isEqualTo("testMethod"); - assertThat(callback1.get().getDescription()).isEqualTo("Test description"); - - var callback2 = Stream.of(callbacks).filter(c -> c.getName().equals("testStaticMethod")).findFirst(); - assertThat(callback2).isPresent(); - assertThat(callback2.get().getName()).isEqualTo("testStaticMethod"); - assertThat(callback2.get().getDescription()).isEqualTo("Test description"); - } - - @Test - void shouldEnsureUniqueToolNames() { - ToolsWithDuplicates testComponent = new ToolsWithDuplicates(); - MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build(); - - assertThatThrownBy(provider::getToolCallbacks).isInstanceOf(IllegalStateException.class) - .hasMessageContaining("Multiple tools with the same name (testMethod) found in sources: " - + testComponent.getClass().getName()); - } - - static class Tools { - - @Tool("Test description") - static List testStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - List testMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - Function testFunction(String input) { - // This method should be ignored as it's a functional type, which is not - // supported. - return String::length; - } - - void nonToolMethod() { - // This method should be ignored as it doesn't have @Tool annotation - } - - } - - static class ToolsWithDuplicates { - - @Tool(name = "testMethod", value = "Test description") - List testMethod1(String input) { - return List.of(input); - } - - @Tool(name = "testMethod", value = "Test description") - List testMethod2(String input) { - return List.of(input); - } - - } - -} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java deleted file mode 100644 index 463fdfa..0000000 --- a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/core/tools/method/MethodToolCallbackTests.java +++ /dev/null @@ -1,193 +0,0 @@ -package io.arconia.ai.tools.method; - -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.List; - -import com.fasterxml.jackson.core.type.TypeReference; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.springframework.util.ReflectionUtils; - -import io.arconia.ai.tools.annotation.Tool; -import io.arconia.ai.tools.definition.ToolDefinition; -import io.arconia.ai.tools.json.JsonParser; -import io.arconia.ai.tools.metadata.ToolMetadata; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Unit tests for {@link MethodToolCallback}. - */ -class MethodToolCallbackTests { - - @ParameterizedTest - @ValueSource(strings = { - "publicStaticMethod", - "privateStaticMethod", - "packageStaticMethod", - "publicMethod", - "privateMethod", - "packageMethod" - }) - void shouldCallToolFromPublicClass(String methodName) { - validateAssertions(methodName, new PublicTools()); - } - - @ParameterizedTest - @ValueSource(strings = { - "publicStaticMethod", - "privateStaticMethod", - "packageStaticMethod", - "publicMethod", - "privateMethod", - "packageMethod" - }) - void shouldCallToolFromPrivateClass(String methodName) { - validateAssertions(methodName, new PrivateTools()); - } - - @ParameterizedTest - @ValueSource(strings = { - "publicStaticMethod", - "privateStaticMethod", - "packageStaticMethod", - "publicMethod", - "privateMethod", - "packageMethod" - }) - void shouldCallToolFromPackageClass(String methodName) { - validateAssertions(methodName, new PackageTools()); - } - - private static void validateAssertions(String methodName, Object toolObject) { - Method toolMethod = getMethod(methodName, toolObject.getClass()); - assertThat(toolMethod).isNotNull(); - MethodToolCallback callback = MethodToolCallback.builder() - .toolDefinition(ToolDefinition.from(toolMethod)) - .toolMetadata(ToolMetadata.from(toolMethod)) - .toolMethod(toolMethod) - .toolObject(toolObject) - .build(); - - String result = callback.call(""" - { - "input": "Wingardium Leviosa" - } - """); - - assertThat(JsonParser.fromJson(result, new TypeReference>() {})) - .contains("Wingardium Leviosa"); - } - - private static Method getMethod(String name, Class toolsClass) { - return Arrays.stream(ReflectionUtils.getDeclaredMethods(toolsClass)) - .filter(m -> m.getName().equals(name)) - .findFirst() - .orElseThrow(); - } - - static public class PublicTools { - - @Tool("Test description") - public static List publicStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - private static List privateStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - static List packageStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - public List publicMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - private List privateMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - List packageMethod(String input) { - return List.of(input); - } - - } - - static private class PrivateTools { - - @Tool("Test description") - public static List publicStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - private static List privateStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - static List packageStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - public List publicMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - private List privateMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - List packageMethod(String input) { - return List.of(input); - } - - } - - static class PackageTools { - - @Tool("Test description") - public static List publicStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - private static List privateStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - static List packageStaticMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - public List publicMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - private List privateMethod(String input) { - return List.of(input); - } - - @Tool("Test description") - List packageMethod(String input) { - return List.of(input); - } - - } - -} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/ToolCallbackTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/ToolCallbackTests.java new file mode 100644 index 0000000..254110c --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/ToolCallbackTests.java @@ -0,0 +1,45 @@ +package io.arconia.ai.tools; + +import org.junit.jupiter.api.Test; + +import io.arconia.ai.tools.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolCallback}. + */ +class ToolCallbackTests { + + @Test + void shouldOnlyImplementRequiredMethods() { + var testToolCallback = new TestToolCallback("test"); + assertThat(testToolCallback.getToolDefinition()).isNotNull(); + assertThat(testToolCallback.getToolMetadata()).isNotNull(); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder() + .name(name) + .description(name) + .inputTypeSchema("{}") + .build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return ""; + } + + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/definition/DefaultToolDefinitionTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/definition/DefaultToolDefinitionTests.java new file mode 100644 index 0000000..dc51492 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/definition/DefaultToolDefinitionTests.java @@ -0,0 +1,63 @@ +package io.arconia.ai.tools.definition; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link DefaultToolDefinition}. + */ +class DefaultToolDefinitionTests { + + @Test + void shouldCreateDefaultToolDefinition() { + var toolDefinition = new DefaultToolDefinition("name", "description", "{}"); + assertThat(toolDefinition.name()).isEqualTo("name"); + assertThat(toolDefinition.description()).isEqualTo("description"); + assertThat(toolDefinition.inputTypeSchema()).isEqualTo("{}"); + } + + @Test + void shouldThrowExceptionWhenNameIsNull() { + assertThatThrownBy(() -> new DefaultToolDefinition(null, "description", "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenNameIsEmpty() { + assertThatThrownBy(() -> new DefaultToolDefinition("", "description", "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenDescriptionIsNull() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", null, "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenDescriptionIsEmpty() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", "", "{}")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("description cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenInputTypeSchemaIsNull() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", "description", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("inputTypeSchema cannot be null or empty"); + } + + @Test + void shouldThrowExceptionWhenInputTypeSchemaIsEmpty() { + assertThatThrownBy(() -> new DefaultToolDefinition("name", "description", "")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("inputTypeSchema cannot be null or empty"); + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/definition/ToolDefinitionTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/definition/ToolDefinitionTests.java new file mode 100644 index 0000000..fb9795f --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/definition/ToolDefinitionTests.java @@ -0,0 +1,57 @@ +package io.arconia.ai.tools.definition; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import io.arconia.ai.tools.annotation.Tool; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolDefinition}. + */ +class ToolDefinitionTests { + + @Test + void shouldCreateDefaultToolDefinitionBuilder() { + var toolDefinition = ToolDefinition.builder() + .name("name") + .description("description") + .inputTypeSchema("{}") + .build(); + assertThat(toolDefinition.name()).isEqualTo("name"); + assertThat(toolDefinition.description()).isEqualTo("description"); + assertThat(toolDefinition.inputTypeSchema()).isEqualTo("{}"); + } + + @Test + void shouldCreateToolDefinitionFromMethod() { + var toolDefinition = ToolDefinition.from(Tools.class.getDeclaredMethods()[0]); + assertThat(toolDefinition.name()).isEqualTo("mySuperTool"); + assertThat(toolDefinition.description()).isEqualTo("Test description"); + assertThat(toolDefinition.inputTypeSchema()).isEqualToIgnoringWhitespace(""" + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "object", + "properties" : { + "input" : { + "type" : "string" + } + }, + "required" : [ "input" ], + "additionalProperties" : false + } + """); + } + + static class Tools { + + @Tool("Test description") + public List mySuperTool(String input) { + return List.of(input); + } + + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/execution/DefaultToolCallResultConverterTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/execution/DefaultToolCallResultConverterTests.java new file mode 100644 index 0000000..5055c6e --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/execution/DefaultToolCallResultConverterTests.java @@ -0,0 +1,95 @@ +package io.arconia.ai.tools.execution; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link DefaultToolCallResultConverter}. + */ +class DefaultToolCallResultConverterTests { + + private final DefaultToolCallResultConverter converter = new DefaultToolCallResultConverter(); + + @Test + void convertWithNullReturnTypeShouldThrowException() { + assertThatThrownBy(() -> converter.apply(null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("returnType cannot be null"); + } + + @Test + void convertVoidReturnTypeShouldReturnDone() { + String result = converter.apply(null, void.class); + assertThat(result).isEqualTo("Done"); + } + + @Test + void convertStringReturnTypeShouldReturnJson() { + String result = converter.apply("test", String.class); + assertThat(result).isEqualTo("\"test\""); + } + + @Test + void convertNullReturnValueShouldReturnNullJson() { + String result = converter.apply(null, String.class); + assertThat(result).isEqualTo("null"); + } + + @Test + void convertObjectReturnTypeShouldReturnJson() { + TestObject testObject = new TestObject("test", 42); + String result = converter.apply(testObject, TestObject.class); + assertThat(result) + .containsIgnoringWhitespaces(""" + "name": "test" + """) + .containsIgnoringWhitespaces(""" + "value": 42 + """); + } + + @Test + void convertCollectionReturnTypeShouldReturnJson() { + List testList = List.of("one", "two", "three"); + String result = converter.apply(testList, List.class); + assertThat(result).isEqualTo(""" + ["one","two","three"] + """.trim()); + } + + @Test + void convertMapReturnTypeShouldReturnJson() { + Map testMap = Map.of("one", 1, "two", 2); + String result = converter.apply(testMap, Map.class); + assertThat(result) + .containsIgnoringWhitespaces(""" + "one": 1 + """) + .containsIgnoringWhitespaces(""" + "two": 2 + """); + } + + static class TestObject { + private final String name; + private final int value; + + TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + public String getName() { + return name; + } + + public int getValue() { + return value; + } + } +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/execution/ToolExecutionExceptionTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/execution/ToolExecutionExceptionTests.java new file mode 100644 index 0000000..e733e49 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/execution/ToolExecutionExceptionTests.java @@ -0,0 +1,34 @@ +package io.arconia.ai.tools.execution; + +import io.arconia.ai.tools.definition.ToolDefinition; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link ToolExecutionException}. + */ +class ToolExecutionExceptionTests { + + @Test + void constructorShouldSetCauseAndMessage() { + String errorMessage = "Test error message"; + RuntimeException cause = new RuntimeException(errorMessage); + + ToolExecutionException exception = new ToolExecutionException(mock(ToolDefinition.class), cause); + + assertThat(exception.getCause()).isEqualTo(cause); + assertThat(exception.getMessage()).isEqualTo(errorMessage); + } + + @Test + void getToolDefinitionShouldReturnToolDefinition() { + RuntimeException cause = new RuntimeException("Test error"); + ToolDefinition toolDefinition = mock(ToolDefinition.class); + ToolExecutionException exception = new ToolExecutionException(toolDefinition, cause); + + assertThat(exception.getToolDefinition()).isEqualTo(toolDefinition); + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/json/JsonParserTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/json/JsonParserTests.java new file mode 100644 index 0000000..90aa9b4 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/json/JsonParserTests.java @@ -0,0 +1,217 @@ +package io.arconia.ai.tools.json; + +import com.fasterxml.jackson.core.type.TypeReference; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for the {@link JsonParser} class. + */ +class JsonParserTests { + + @Test + void shouldGetObjectMapper() { + var objectMapper = JsonParser.getObjectMapper(); + assertThat(objectMapper).isNotNull(); + } + + @Test + void shouldThrowExceptionWhenJsonIsNull() { + assertThatThrownBy(() -> JsonParser.fromJson(null, TestRecord.class)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("json cannot be null"); + } + + @Test + void shouldThrowExceptionWhenClassIsNull() { + assertThatThrownBy(() -> JsonParser.fromJson("{}", (Class) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void shouldThrowExceptionWhenTypeIsNull() { + assertThatThrownBy(() -> JsonParser.fromJson("{}", (TypeReference) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + @Test + void fromJsonToObject() { + var json = """ + { + "name" : "John", + "age" : 30 + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isEqualTo(30); + } + + @Test + void fromJsonToObjectWithMissingProperty() { + var json = """ + { + "name": "John" + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isNull(); + } + + @Test + void fromJsonToObjectWithNullProperty() { + var json = """ + { + "name": "John", + "age": null + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isNull(); + } + + @Test + void fromJsonToObjectWithOtherNullProperty() { + var json = """ + { + "name": null, + "age": 21 + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isNull(); + assertThat(object.age).isEqualTo(21); + } + + @Test + void fromJsonToObjectWithUnknownProperty() { + var json = """ + { + "name": "James", + "surname": "Bond" + } + """; + var object = JsonParser.fromJson(json, TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("James"); + } + + @Test + void fromObjectToJson() { + var object = new TestRecord("John", 30); + var json = JsonParser.toJson(object); + assertThat(json).isEqualToIgnoringWhitespace(""" + { + "name" : "John", + "age" : 30 + } + """); + } + + @Test + void fromObjectToJsonWithNullValues() { + var object = new TestRecord("John", null); + var json = JsonParser.toJson(object); + assertThat(json).isEqualToIgnoringWhitespace(""" + { + "name" : "John", + "age" : null + } + """); + } + + @Test + void fromNullObjectToJson() { + var json = JsonParser.toJson(null); + assertThat(json).isEqualToIgnoringWhitespace("null"); + } + + @Test + void fromObjectToString() { + var value = JsonParser.toTypedObject("John", String.class); + assertThat(value).isOfAnyClassIn(String.class); + assertThat(value).isEqualTo("John"); + } + + @Test + void fromObjectToByte() { + var value = JsonParser.toTypedObject("1", Byte.class); + assertThat(value).isOfAnyClassIn(Byte.class); + assertThat(value).isEqualTo((byte) 1); + } + + @Test + void fromObjectToInteger() { + var value = JsonParser.toTypedObject("1", Integer.class); + assertThat(value).isOfAnyClassIn(Integer.class); + assertThat(value).isEqualTo(1); + } + + @Test + void fromObjectToShort() { + var value = JsonParser.toTypedObject("1", Short.class); + assertThat(value).isOfAnyClassIn(Short.class); + assertThat(value).isEqualTo((short) 1); + } + + @Test + void fromObjectToLong() { + var value = JsonParser.toTypedObject("1", Long.class); + assertThat(value).isOfAnyClassIn(Long.class); + assertThat(value).isEqualTo(1L); + } + + @Test + void fromObjectToDouble() { + var value = JsonParser.toTypedObject("1.0", Double.class); + assertThat(value).isOfAnyClassIn(Double.class); + assertThat(value).isEqualTo(1.0); + } + + @Test + void fromObjectToFloat() { + var value = JsonParser.toTypedObject("1.0", Float.class); + assertThat(value).isOfAnyClassIn(Float.class); + assertThat(value).isEqualTo(1.0f); + } + + @Test + void fromObjectToBoolean() { + var value = JsonParser.toTypedObject("true", Boolean.class); + assertThat(value).isOfAnyClassIn(Boolean.class); + assertThat(value).isEqualTo(true); + } + + @Test + void fromObjectToEnum() { + var value = JsonParser.toTypedObject("VALUE", TestEnum.class); + assertThat(value).isOfAnyClassIn(TestEnum.class); + assertThat(value).isEqualTo(TestEnum.VALUE); + } + + @Test + void fromObjectToRecord() { + var record = new TestRecord("John", 30); + var value = JsonParser.toTypedObject(record, TestRecord.class); + assertThat(value).isOfAnyClassIn(TestRecord.class); + assertThat(value).isEqualTo(new TestRecord("John", 30)); + } + + record TestRecord(String name, Integer age) {} + + enum TestEnum { + VALUE + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/json/JsonSchemaGeneratorTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/json/JsonSchemaGeneratorTests.java new file mode 100644 index 0000000..4a80206 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/json/JsonSchemaGeneratorTests.java @@ -0,0 +1,355 @@ +package io.arconia.ai.tools.json; + +import java.lang.reflect.Method; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.Month; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link JsonSchemaGenerator}. + */ +class JsonSchemaGeneratorTests { + + @Test + void generateSchemaForMethodWithSimpleParameters() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer", + "format" : "int32" + } + }, + "required": [ + "name", + "age" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithJsonPropertyAnnotations() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("annotatedMethod", String.class, String.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method, JsonSchemaGenerator.SchemaOption.RESPECT_JSON_PROPERTY_REQUIRED); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "password": { + "type": "string" + } + }, + "required": [ + "password" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithAdditionalPropertiesAllowed() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method, JsonSchemaGenerator.SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT); + + JsonNode jsonNode = JsonParser.getObjectMapper().readTree(schema); + assertThat(jsonNode.has("additionalProperties")).isFalse(); + } + + @Test + void generateSchemaForMethodWithUpperCaseTypes() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method, JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "OBJECT", + "properties": { + "name": { + "type": "STRING" + }, + "age": { + "type": "INTEGER", + "format" : "int32" + } + }, + "required": [ + "name", + "age" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithComplexParameters() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("complexMethod", List.class, TestData.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method); + + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "string" + } + }, + "data": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format" : "int32" + }, + "name": { + "type": "string" + } + } + } + }, + "required": [ + "items", + "data" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForMethodWithTimeParameters() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("timeMethod", Duration.class, LocalDateTime.class, Instant.class); + + String schema = JsonSchemaGenerator.generateForMethodInput(method); + String expectedJsonSchema = """ + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "duration": { + "type": "string", + "format" : "duration" + }, + "localDateTime": { + "type": "string", + "format": "date-time" + }, + "instant": { + "type": "string", + "format": "date-time" + } + }, + "required": [ + "duration", + "localDateTime", + "instant" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForSimpleType() { + String schema = JsonSchemaGenerator.generateForType(Person.class); + String expectedJsonSchema = """ + { + "type": "object", + "properties": { + "email": { + "type": "string" + }, + "id": { + "type": "integer", + "format" : "int32" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForTypeWithAdditionalPropertiesAllowed() throws JsonProcessingException { + String schema = JsonSchemaGenerator.generateForType(Person.class, + JsonSchemaGenerator.SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT); + + JsonNode jsonNode = JsonParser.getObjectMapper().readTree(schema); + assertThat(jsonNode.has("additionalProperties")).isFalse(); + } + + @Test + void generateSchemaForTypeWithUpperCaseValues() { + String schema = JsonSchemaGenerator.generateForType(Person.class, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES); + String expectedJsonSchema = """ + { + "type": "OBJECT", + "properties": { + "email": { + "type": "STRING" + }, + "id": { + "type": "INTEGER", + "format" : "int32" + }, + "name": { + "type": "STRING" + } + }, + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForRecord() { + String schema = JsonSchemaGenerator.generateForType(TestData.class); + String expectedJsonSchema = """ + { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format" : "int32" + }, + "name": { + "type": "string" + } + }, + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void generateSchemaForEnum() { + String schema = JsonSchemaGenerator.generateForType(Month.class); + String expectedJsonSchema = """ + { + "type": "string", + "enum": [ + "JANUARY", + "FEBRUARY", + "MARCH", + "APRIL", + "MAY", + "JUNE", + "JULY", + "AUGUST", + "SEPTEMBER", + "OCTOBER", + "NOVEMBER", + "DECEMBER" + ], + "additionalProperties": false + } + """; + + assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); + } + + @Test + void throwExceptionWhenTypeIsNull() { + assertThatThrownBy(() -> JsonSchemaGenerator.generateForType(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("type cannot be null"); + } + + static class TestMethods { + + public void simpleMethod(String name, int age) {} + + public void annotatedMethod(String username, @JsonProperty(required = true) String password) {} + + public void complexMethod(List items, TestData data) {} + + public void timeMethod(Duration duration, LocalDateTime localDateTime, Instant instant) {} + + } + + record TestData(int id, String name) {} + + static class Person { + private int id; + private String name; + private String email; + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getEmail() { + return email; + } + + public void setEmail(String email) { + this.email = email; + } + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/metadata/DefaultToolMetadataTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/metadata/DefaultToolMetadataTests.java new file mode 100644 index 0000000..a240c5f --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/metadata/DefaultToolMetadataTests.java @@ -0,0 +1,31 @@ +package io.arconia.ai.tools.metadata; + +import org.junit.jupiter.api.Test; + +import io.arconia.ai.tools.execution.ToolExecutionMode; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link DefaultToolMetadata}. + */ +class DefaultToolMetadataTests { + + @Test + void shouldCreateDefaultToolMetadataWithDefaultValues() { + var toolMetadata = DefaultToolMetadata.builder().build(); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isFalse(); + } + + @Test + void shouldCreateDefaultToolMetadataWithGivenValues() { + var toolMetadata = DefaultToolMetadata.builder() + .executionMode(ToolExecutionMode.BLOCKING) + .returnDirect(true) + .build(); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isTrue(); + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/metadata/ToolMetadataTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/metadata/ToolMetadataTests.java new file mode 100644 index 0000000..7a4c6a0 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/metadata/ToolMetadataTests.java @@ -0,0 +1,40 @@ +package io.arconia.ai.tools.metadata; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import io.arconia.ai.tools.annotation.Tool; +import io.arconia.ai.tools.execution.ToolExecutionMode; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolMetadata}. + */ +class ToolMetadataTests { + + @Test + void shouldCreateDefaultToolMetadataBuilder() { + var toolMetadata = ToolMetadata.builder().build(); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isFalse(); + } + + @Test + void shouldCreateToolMetadataFromMethod() { + var toolMetadata = ToolMetadata.from(Tools.class.getDeclaredMethods()[0]); + assertThat(toolMetadata.executionMode()).isEqualTo(ToolExecutionMode.BLOCKING); + assertThat(toolMetadata.returnDirect()).isTrue(); + } + + static class Tools { + + @Tool(value = "Test description", returnDirect = true) + public List mySuperTool(String input) { + return List.of(input); + } + + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/method/MethodToolCallbackProviderTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/method/MethodToolCallbackProviderTests.java new file mode 100644 index 0000000..76ccded --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/method/MethodToolCallbackProviderTests.java @@ -0,0 +1,249 @@ +package io.arconia.ai.tools.method; + +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import io.arconia.ai.tools.ToolCallback; +import io.arconia.ai.tools.annotation.Tool; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MethodToolCallbackProvider}. + */ +class MethodToolCallbackProviderTests { + + @Nested + class BuilderValidationTests { + + @Test + void shouldRejectNullToolObjects() { + assertThatThrownBy(() -> MethodToolCallbackProvider.builder().toolObjects(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolObjects cannot be null"); + } + + @Test + void shouldRejectNullToolObjectElements() { + assertThatThrownBy(() -> MethodToolCallbackProvider.builder().toolObjects(new Tools(), null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolObjects cannot contain null elements"); + } + + @Test + void shouldAcceptEmptyToolObjects() { + var provider = MethodToolCallbackProvider.builder().toolObjects().build(); + assertThat(provider.getToolCallbacks()).isEmpty(); + } + + } + + @Test + void shouldProvideToolCallbacksFromObject() { + Tools tools = new Tools(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(tools).build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + + assertThat(callbacks).hasSize(2); + + var callback1 = Stream.of(callbacks).filter(c -> c.getName().equals("testMethod")).findFirst(); + assertThat(callback1).isPresent(); + assertThat(callback1.get().getName()).isEqualTo("testMethod"); + assertThat(callback1.get().getDescription()).isEqualTo("Test description"); + + var callback2 = Stream.of(callbacks).filter(c -> c.getName().equals("testStaticMethod")).findFirst(); + assertThat(callback2).isPresent(); + assertThat(callback2.get().getName()).isEqualTo("testStaticMethod"); + assertThat(callback2.get().getDescription()).isEqualTo("Test description"); + } + + @Test + void shouldProvideToolCallbacksFromMultipleObjects() { + Tools tools1 = new Tools(); + ToolsExtra tools2 = new ToolsExtra(); + + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(tools1, tools2) + .build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(4); // 2 from Tools + 2 from ToolsExtra + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)) + .containsExactlyInAnyOrder("testMethod", "testStaticMethod", "extraMethod1", "extraMethod2"); + } + + @Test + void shouldEnsureUniqueToolNames() { + ToolsWithDuplicates testComponent = new ToolsWithDuplicates(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder().toolObjects(testComponent).build(); + + assertThatThrownBy(provider::getToolCallbacks).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Multiple tools with the same name (testMethod) found in sources: " + + testComponent.getClass().getName()); + } + + @Test + void shouldHandleToolMethodsWithDifferentVisibility() { + ToolsWithVisibility tools = new ToolsWithVisibility(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(tools) + .build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(3); + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)) + .containsExactlyInAnyOrder("publicMethod", "protectedMethod", "privateMethod"); + } + + @Test + void shouldHandleToolMethodsWithDifferentParameters() { + ToolsWithParameters tools = new ToolsWithParameters(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(tools) + .build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(3); + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)) + .containsExactlyInAnyOrder("noParams", "oneParam", "multipleParams"); + } + + @Test + void shouldHandleToolMethodsWithDifferentReturnTypes() { + ToolsWithReturnTypes tools = new ToolsWithReturnTypes(); + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(tools) + .build(); + + ToolCallback[] callbacks = provider.getToolCallbacks(); + assertThat(callbacks).hasSize(4); + + assertThat(Stream.of(callbacks).map(ToolCallback::getName)) + .containsExactlyInAnyOrder("voidMethod", "primitiveMethod", "objectMethod", "collectionMethod"); + } + + static class Tools { + + @Tool("Test description") + static List testStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List testMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + Function testFunction(String input) { + // This method should be ignored as it's a functional type + return String::length; + } + + @Tool("Test description") + Consumer testConsumer(String input) { + // This method should be ignored as it's a functional type + return System.out::println; + } + + @Tool("Test description") + Supplier testSupplier() { + // This method should be ignored as it's a functional type + return () -> "test"; + } + + void nonToolMethod() { + // This method should be ignored as it doesn't have @Tool annotation + } + } + + static class ToolsExtra { + @Tool("Extra method 1") + String extraMethod1() { + return "extra1"; + } + + @Tool("Extra method 2") + String extraMethod2() { + return "extra2"; + } + } + + static class ToolsWithDuplicates { + @Tool(name = "testMethod", value = "Test description") + List testMethod1(String input) { + return List.of(input); + } + + @Tool(name = "testMethod", value = "Test description") + List testMethod2(String input) { + return List.of(input); + } + } + + static class ToolsWithVisibility { + @Tool("Public method") + public String publicMethod() { + return "public"; + } + + @Tool("Protected method") + protected String protectedMethod() { + return "protected"; + } + + @Tool("Private method") + private String privateMethod() { + return "private"; + } + } + + static class ToolsWithParameters { + @Tool("No parameters") + String noParams() { + return "no params"; + } + + @Tool("One parameter") + String oneParam(String param) { + return param; + } + + @Tool("Multiple parameters") + String multipleParams(String param1, int param2, boolean param3) { + return param1 + param2 + param3; + } + } + + static class ToolsWithReturnTypes { + @Tool("Void method") + void voidMethod() { + } + + @Tool("Primitive method") + int primitiveMethod() { + return 42; + } + + @Tool("Object method") + String objectMethod() { + return "object"; + } + + @Tool("Collection method") + List collectionMethod() { + return List.of("collection"); + } + } +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/method/MethodToolCallbackTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/method/MethodToolCallbackTests.java new file mode 100644 index 0000000..6c18c3e --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/method/MethodToolCallbackTests.java @@ -0,0 +1,360 @@ +package io.arconia.ai.tools.method; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.core.type.TypeReference; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.util.ReflectionUtils; + +import io.arconia.ai.tools.annotation.Tool; +import io.arconia.ai.tools.definition.ToolDefinition; +import io.arconia.ai.tools.execution.ToolExecutionException; +import io.arconia.ai.tools.json.JsonParser; +import io.arconia.ai.tools.metadata.ToolMetadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MethodToolCallback}. + */ +class MethodToolCallbackTests { + + @ParameterizedTest + @ValueSource(strings = { + "publicStaticMethod", + "privateStaticMethod", + "packageStaticMethod", + "publicMethod", + "privateMethod", + "packageMethod" + }) + void shouldCallToolFromPublicClass(String methodName) { + validateAssertions(methodName, new PublicTools()); + } + + @ParameterizedTest + @ValueSource(strings = { + "publicStaticMethod", + "privateStaticMethod", + "packageStaticMethod", + "publicMethod", + "privateMethod", + "packageMethod" + }) + void shouldCallToolFromPrivateClass(String methodName) { + validateAssertions(methodName, new PrivateTools()); + } + + @ParameterizedTest + @ValueSource(strings = { + "publicStaticMethod", + "privateStaticMethod", + "packageStaticMethod", + "publicMethod", + "privateMethod", + "packageMethod" + }) + void shouldCallToolFromPackageClass(String methodName) { + validateAssertions(methodName, new PackageTools()); + } + + @Test + void shouldHandleToolContextWhenSupported() { + Method toolMethod = getMethod("methodWithToolContext", ToolContextTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ToolContextTools()) + .build(); + + ToolContext toolContext = new ToolContext(Map.of("key", "value")); + String result = callback.call(""" + { + "input": "test" + } + """, toolContext); + + assertThat(result).contains("value"); + } + + @Test + void shouldThrowExceptionWhenToolContextNotSupported() { + Method toolMethod = getMethod("publicMethod", PublicTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new PublicTools()) + .build(); + + ToolContext toolContext = new ToolContext(Map.of("key", "value")); + + assertThatThrownBy(() -> callback.call(""" + { + "input": "test" + } + """, toolContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ToolContext is not supported"); + } + + @Test + void shouldHandleComplexArguments() { + Method toolMethod = getMethod("complexArgumentMethod", ComplexTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ComplexTools()) + .build(); + + String result = callback.call(""" + { + "stringArg": "test", + "intArg": 42, + "listArg": ["a", "b", "c"], + "optionalArg": null + } + """); + + assertThat(JsonParser.fromJson(result, new TypeReference>() {})) + .containsEntry("stringValue", "test") + .containsEntry("intValue", 42) + .containsEntry("listSize", 3); + } + + @Test + void shouldHandleCustomResultConverter() { + Method toolMethod = getMethod("publicMethod", PublicTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new PublicTools()) + .toolCallResultConverter((result, type) -> "Converted: " + result) + .build(); + + String result = callback.call(""" + { + "input": "test" + } + """); + + assertThat(result).startsWith("Converted:"); + } + + @Test + void shouldHandleMethodExecutionError() { + Method toolMethod = getMethod("errorMethod", ErrorTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ErrorTools()) + .build(); + + String result = callback.call(""" + { + "input": "test" + } + """); + + assertThat(result) + .contains("Test error"); + } + + @Test + void shouldThrowExceptionWhenExecutionErrorAndReturnDirect() { + Method toolMethod = getMethod("errorMethodDirect", ErrorTools.class); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(new ErrorTools()) + .build(); + + assertThatThrownBy(() -> callback.call(""" + { + "input": "test" + } + """)) + .isInstanceOf(ToolExecutionException.class) + .hasMessageContaining("Test error"); + } + + private static void validateAssertions(String methodName, Object toolObject) { + Method toolMethod = getMethod(methodName, toolObject.getClass()); + assertThat(toolMethod).isNotNull(); + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.from(toolMethod)) + .toolMetadata(ToolMetadata.from(toolMethod)) + .toolMethod(toolMethod) + .toolObject(toolObject) + .build(); + + String result = callback.call(""" + { + "input": "Wingardium Leviosa" + } + """); + + assertThat(JsonParser.fromJson(result, new TypeReference>() {})) + .contains("Wingardium Leviosa"); + } + + private static Method getMethod(String name, Class toolsClass) { + return Arrays.stream(ReflectionUtils.getDeclaredMethods(toolsClass)) + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(); + } + + static public class PublicTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static private class PrivateTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static class PackageTools { + + @Tool("Test description") + public static List publicStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private static List privateStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + static List packageStaticMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + public List publicMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + private List privateMethod(String input) { + return List.of(input); + } + + @Tool("Test description") + List packageMethod(String input) { + return List.of(input); + } + + } + + static class ToolContextTools { + + @Tool("Test description") + public String methodWithToolContext(String input, ToolContext toolContext) { + return input + ": " + toolContext.getContext().get("key"); + } + + } + + static class ComplexTools { + + @Tool("Test description") + public Map complexArgumentMethod(String stringArg, int intArg, List listArg, String optionalArg) { + return Map.of( + "stringValue", stringArg, + "intValue", intArg, + "listSize", listArg.size(), + "optionalProvided", optionalArg != null + ); + } + + } + + static class ErrorTools { + + @Tool("Test description") + public String errorMethod(String input) { + throw new IllegalArgumentException("Test error"); + } + + @Tool(value = "Test description", returnDirect = true) + public String errorMethodDirect(String input) { + throw new IllegalArgumentException("Test error"); + } + + } + +} diff --git a/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/utils/ToolUtilsTests.java b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/utils/ToolUtilsTests.java new file mode 100644 index 0000000..7e3c730 --- /dev/null +++ b/arconia-ai/arconia-ai-tools/src/test/java/io/arconia/ai/tools/utils/ToolUtilsTests.java @@ -0,0 +1,204 @@ +package io.arconia.ai.tools.utils; + +import java.lang.reflect.Method; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import io.arconia.ai.tools.ToolCallback; +import io.arconia.ai.tools.annotation.Tool; +import io.arconia.ai.tools.definition.ToolDefinition; +import io.arconia.ai.tools.execution.DefaultToolCallResultConverter; +import io.arconia.ai.tools.execution.ToolCallResultConverter; +import io.arconia.ai.tools.execution.ToolExecutionMode; +import io.arconia.ai.tools.util.ToolUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ToolUtils}. + */ +class ToolUtilsTests { + + @Test + void shouldDetectDuplicateToolNames() { + ToolCallback callback1 = new TestToolCallback("tool_a"); + ToolCallback callback2 = new TestToolCallback("tool_a"); + ToolCallback callback3 = new TestToolCallback("tool_b"); + + List duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3); + + assertThat(duplicates).isNotEmpty(); + assertThat(duplicates).contains("tool_a"); + } + + @Test + void shouldNotDetectDuplicateToolNames() { + ToolCallback callback1 = new TestToolCallback("tool_a"); + ToolCallback callback2 = new TestToolCallback("tool_b"); + ToolCallback callback3 = new TestToolCallback("tool_c"); + + List duplicates = ToolUtils.getDuplicateToolNames(callback1, callback2, callback3); + + assertThat(duplicates).isEmpty(); + } + + @Test + void shouldGetToolNameFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomName"); + assertThat(ToolUtils.getToolName(method)).isEqualTo("customName"); + } + + @Test + void shouldGetMethodNameWhenNoCustomNameInAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithoutCustomName"); + assertThat(ToolUtils.getToolName(method)).isEqualTo("toolWithoutCustomName"); + } + + @Test + void shouldGetMethodNameWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + assertThat(ToolUtils.getToolName(method)).isEqualTo("methodWithoutAnnotation"); + } + + @Test + void shouldGetToolDescriptionFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomDescription"); + assertThat(ToolUtils.getToolDescription(method)).isEqualTo("Custom description"); + } + + @Test + void shouldGetMethodNameWhenNoCustomDescriptionInAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithoutCustomDescription"); + assertThat(ToolUtils.getToolDescription(method)).isEqualTo("toolWithoutCustomDescription"); + } + + @Test + void shouldGetFormattedMethodNameWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("camelCaseMethodWithoutAnnotation"); + assertThat(ToolUtils.getToolDescription(method)).isEqualTo("camel case method without annotation"); + } + + @Test + void shouldGetToolExecutionModeFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomExecutionMode"); + assertThat(ToolUtils.getToolExecutionMode(method)).isEqualTo(ToolExecutionMode.BLOCKING); + } + + @Test + void shouldGetDefaultExecutionModeWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + assertThat(ToolUtils.getToolExecutionMode(method)).isEqualTo(ToolExecutionMode.BLOCKING); + } + + @Test + void shouldGetToolReturnDirectFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithReturnDirect"); + assertThat(ToolUtils.getToolReturnDirect(method)).isTrue(); + } + + @Test + void shouldGetDefaultReturnDirectWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + assertThat(ToolUtils.getToolReturnDirect(method)).isFalse(); + } + + @Test + void shouldGetToolCallResultConverterFromAnnotation() throws Exception { + Method method = TestTools.class.getMethod("toolWithCustomConverter"); + ToolCallResultConverter converter = ToolUtils.getToolCallResultConverter(method); + assertThat(converter).isInstanceOf(CustomToolCallResultConverter.class); + } + + @Test + void shouldGetDefaultConverterWhenNoAnnotation() throws Exception { + Method method = TestTools.class.getMethod("methodWithoutAnnotation"); + ToolCallResultConverter converter = ToolUtils.getToolCallResultConverter(method); + assertThat(converter).isInstanceOf(DefaultToolCallResultConverter.class); + } + + @Test + void shouldThrowExceptionWhenConverterCannotBeInstantiated() throws Exception { + Method method = TestTools.class.getMethod("toolWithInvalidConverter"); + assertThatThrownBy(() -> ToolUtils.getToolCallResultConverter(method)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failed to instantiate ToolCallResultConverter"); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder() + .name(name) + .description(name) + .inputTypeSchema("{}") + .build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String functionInput) { + return ""; + } + } + + static class TestTools { + + @Tool(name = "customName") + public void toolWithCustomName() {} + + @Tool + public void toolWithoutCustomName() {} + + @Tool(value = "Custom description") + public void toolWithCustomDescription() {} + + @Tool + public void toolWithoutCustomDescription() {} + + @Tool(executionMode = ToolExecutionMode.BLOCKING) + public void toolWithCustomExecutionMode() {} + + @Tool(returnDirect = true) + public void toolWithReturnDirect() {} + + @Tool(resultConverter = CustomToolCallResultConverter.class) + public void toolWithCustomConverter() {} + + @Tool(resultConverter = InvalidToolCallResultConverter.class) + public void toolWithInvalidConverter() {} + + public void methodWithoutAnnotation() {} + + public void camelCaseMethodWithoutAnnotation() {} + + } + + public static class CustomToolCallResultConverter implements ToolCallResultConverter { + + @Override + public String apply(Object result, Class returnType) { + return returnType.getName(); + } + + } + + // No-public class with no-public constructor + static class InvalidToolCallResultConverter implements ToolCallResultConverter { + + private InvalidToolCallResultConverter() {} + + @Override + public String apply(Object result, Class returnType) { + return returnType.getName(); + } + + } +}