Skip to content

Commit

Permalink
feat(ai): Extend support for tools from objects, classes, functions, …
Browse files Browse the repository at this point in the history
…and MCP
  • Loading branch information
ThomasVitale committed Jan 3, 2025
1 parent 981cf00 commit 628cab2
Show file tree
Hide file tree
Showing 17 changed files with 470 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

import io.arconia.ai.core.tools.ToolCallbackResolver;
import io.arconia.ai.core.tools.ToolCallbackProvider;

/**
* A {@link ChatClient} enhanced for more advanced features.
Expand Down Expand Up @@ -133,7 +133,7 @@ interface ArconiaChatClientRequestSpec extends ChatClientRequestSpec {

ArconiaChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);

ArconiaChatClientRequestSpec toolCallbackResolvers(ToolCallbackResolver... toolCallbackResolvers);
ArconiaChatClientRequestSpec toolCallbackProviders(ToolCallbackProvider... toolCallbackProviders);

ArconiaChatClientRequestSpec functions(FunctionCallback... functionCallbacks);

Expand Down Expand Up @@ -215,7 +215,7 @@ interface ArconiaBuilder extends Builder {

ArconiaBuilder defaultToolCallbacks(FunctionCallback... toolCallbacks);

ArconiaBuilder defaultToolCallbackResolvers(ToolCallbackResolver... toolCallbackResolvers);
ArconiaBuilder defaultToolCallbackProviders(ToolCallbackProvider... toolCallbackProviders);

ArconiaBuilder defaultFunctions(String... functionNames);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Stream;

import io.micrometer.observation.ObservationRegistry;

Expand Down Expand Up @@ -44,8 +43,8 @@
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import io.arconia.ai.core.tools.ToolCallbackResolver;
import io.arconia.ai.core.tools.method.MethodToolCallbackResolver;
import io.arconia.ai.core.tools.ToolCallbackProvider;
import io.arconia.ai.core.tools.method.MethodToolCallbackProvider;

/**
* Default implementation of {@link ArconiaChatClient} based on {@link DefaultChatClient}.
Expand Down Expand Up @@ -607,20 +606,16 @@ public ArconiaChatClientRequestSpec tools(String... toolNames) {
public ArconiaChatClientRequestSpec tools(Class<?>... toolBoxes) {
Assert.notNull(toolBoxes, "toolBoxes cannot be null");
Assert.noNullElements(toolBoxes, "toolBoxes cannot contain null elements");
ToolCallbackResolver[] toolCallbackResolvers = Stream.of(toolBoxes)
.map(toolBox -> MethodToolCallbackResolver.builder().type(toolBox).build())
.toArray(ToolCallbackResolver[]::new);
return toolCallbackResolvers(toolCallbackResolvers);
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
return toolCallbackProviders(toolCallbackProvider);
}

@Override
public ArconiaChatClientRequestSpec tools(Object... toolBoxes) {
Assert.notNull(toolBoxes, "toolBoxes cannot be null");
Assert.noNullElements(toolBoxes, "toolBoxes cannot contain null elements");
ToolCallbackResolver[] toolCallbackResolvers = Stream.of(toolBoxes)
.map(toolBox -> MethodToolCallbackResolver.builder().object(toolBox).build())
.toArray(ToolCallbackResolver[]::new);
return toolCallbackResolvers(toolCallbackResolvers);
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
return toolCallbackProviders(toolCallbackProvider);
}

@Override
Expand All @@ -632,9 +627,9 @@ public ArconiaChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallba
}

@Override
public ArconiaChatClientRequestSpec toolCallbackResolvers(ToolCallbackResolver... toolCallbackResolvers) {
for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
this.toolCallbacks.addAll(Arrays.asList(toolCallbackResolver.getToolCallbacks()));
public ArconiaChatClientRequestSpec toolCallbackProviders(ToolCallbackProvider... toolCallbackProviders) {
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
this.toolCallbacks.addAll(Arrays.asList(toolCallbackProvider.getToolCallbacks()));
}
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Stream;

import io.micrometer.observation.ObservationRegistry;

Expand All @@ -21,8 +20,8 @@
import org.springframework.util.Assert;

import io.arconia.ai.core.client.DefaultArconiaChatClient.DefaultArconiaChatClientRequestSpec;
import io.arconia.ai.core.tools.ToolCallbackResolver;
import io.arconia.ai.core.tools.method.MethodToolCallbackResolver;
import io.arconia.ai.core.tools.ToolCallbackProvider;
import io.arconia.ai.core.tools.method.MethodToolCallbackProvider;

/**
* Default implementation of {@link ArconiaChatClient.ArconiaBuilder} based on
Expand Down Expand Up @@ -143,18 +142,14 @@ public ArconiaChatClient.ArconiaBuilder defaultTools(String... toolNames) {

@Override
public ArconiaChatClient.ArconiaBuilder defaultTools(Class<?>... toolBoxes) {
ToolCallbackResolver[] toolCallbackResolvers = Stream.of(toolBoxes)
.map(toolBox -> MethodToolCallbackResolver.builder().type(toolBox).build())
.toArray(ToolCallbackResolver[]::new);
return defaultToolCallbackResolvers(toolCallbackResolvers);
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
return defaultToolCallbackProviders(toolCallbackProvider);
}

@Override
public ArconiaChatClient.ArconiaBuilder defaultTools(Object... toolBoxes) {
ToolCallbackResolver[] toolCallbackResolvers = Stream.of(toolBoxes)
.map(toolBox -> MethodToolCallbackResolver.builder().object(toolBox).build())
.toArray(ToolCallbackResolver[]::new);
return defaultToolCallbackResolvers(toolCallbackResolvers);
ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder().sources(toolBoxes).build();
return defaultToolCallbackProviders(toolCallbackProvider);
}

@Override
Expand All @@ -164,10 +159,10 @@ public ArconiaChatClient.ArconiaBuilder defaultToolCallbacks(FunctionCallback...
}

@Override
public ArconiaChatClient.ArconiaBuilder defaultToolCallbackResolvers(
ToolCallbackResolver... toolCallbackResolvers) {
for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
this.arconiaRequest.functions(toolCallbackResolver.getToolCallbacks());
public ArconiaChatClient.ArconiaBuilder defaultToolCallbackProviders(
ToolCallbackProvider... toolCallbackProviders) {
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
this.arconiaRequest.functions(toolCallbackProvider.getToolCallbacks());
}
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.springframework.ai.model.function.MethodInvokingFunctionCallback;
import org.springframework.ai.util.ParsingUtils;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

Expand All @@ -32,24 +34,23 @@ public class ArconiaToolCallbackBuilder implements FunctionCallback.Builder {
private final static Logger logger = LoggerFactory.getLogger(ArconiaToolCallbackBuilder.class);

@Override
public <I, O> FunctionCallback.FunctionInvokingSpec<I, O> function(String name, Function<I, O> function) {
public <I, O> ArconiaFunctionInvokingSpec<I, O> function(String name, Function<I, O> function) {
return new ArconiaFunctionInvokingSpec<>(name, function);
}

@Override
public <I, O> FunctionCallback.FunctionInvokingSpec<I, O> function(String name,
BiFunction<I, ToolContext, O> biFunction) {
public <I, O> ArconiaFunctionInvokingSpec<I, O> function(String name, BiFunction<I, ToolContext, O> biFunction) {
return new ArconiaFunctionInvokingSpec<>(name, biFunction);
}

@Override
public <O> FunctionCallback.FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier) {
public <O> ArconiaFunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier) {
Function<Void, O> function = input -> supplier.get();
return new ArconiaFunctionInvokingSpec<>(name, function).inputType(Void.class);
}

@Override
public <I> FunctionCallback.FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer) {
public <I> ArconiaFunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer) {
Function<I, Void> function = (I input) -> {
consumer.accept(input);
return null;
Expand All @@ -58,14 +59,14 @@ public <I> FunctionCallback.FunctionInvokingSpec<I, Void> function(String name,
}

@Override
public FunctionCallback.MethodInvokingSpec method(String methodName, Class<?>... argumentTypes) {
public ArconiaMethodInvokingSpec method(String methodName, Class<?>... argumentTypes) {
throw new UnsupportedOperationException("Use the 'method(Method method)' method instead");
}

/**
* Create a {@link FunctionCallback.MethodInvokingSpec} for the given method.
*/
public FunctionCallback.MethodInvokingSpec method(Method method) {
public ArconiaMethodInvokingSpec method(Method method) {
return new ArconiaMethodInvokingSpec(method);
}

Expand All @@ -82,27 +83,29 @@ private String generateDescription(String fromName) {
/**
* Arconia {@link FunctionCallback.FunctionInvokingSpec} implementation.
*/
final class ArconiaFunctionInvokingSpec<I, O>
public final class ArconiaFunctionInvokingSpec<I, O>
extends DefaultCommonCallbackInvokingSpec<FunctionCallback.FunctionInvokingSpec<I, O>>
implements FunctionCallback.FunctionInvokingSpec<I, O> {

private final String name;

private Type inputType;

@Nullable
private final BiFunction<I, ToolContext, O> biFunction;

@Nullable
private final Function<I, O> function;

private ArconiaFunctionInvokingSpec(String name, BiFunction<I, ToolContext, O> biFunction) {
private Type inputType;

private ArconiaFunctionInvokingSpec(String name, @NonNull BiFunction<I, ToolContext, O> biFunction) {
Assert.hasText(name, "name cannot be null or empty");
Assert.notNull(biFunction, "biFunction cannot be null");
this.name = name;
this.biFunction = biFunction;
this.function = null;
}

private ArconiaFunctionInvokingSpec(String name, Function<I, O> function) {
private ArconiaFunctionInvokingSpec(String name, @NonNull Function<I, O> function) {
Assert.hasText(name, "name cannot be null or empty");
Assert.notNull(function, "function cannot be null");
this.name = name;
Expand All @@ -111,24 +114,21 @@ private ArconiaFunctionInvokingSpec(String name, Function<I, O> function) {
}

@Override
public FunctionCallback.FunctionInvokingSpec<I, O> inputType(Class<?> inputType) {
public ArconiaFunctionInvokingSpec<I, O> inputType(Class<?> inputType) {
Assert.notNull(inputType, "inputType cannot be null");
this.inputType = inputType;
return this;
}

@Override
public FunctionCallback.FunctionInvokingSpec<I, O> inputType(ParameterizedTypeReference<?> inputType) {
public ArconiaFunctionInvokingSpec<I, O> inputType(ParameterizedTypeReference<?> inputType) {
Assert.notNull(inputType, "inputType cannot be null");
this.inputType = inputType.getType();
return this;
}

@Override
public FunctionCallback build() {
Assert.notNull(this.getObjectMapper(), "objectMapper cannot be null");
Assert.hasText(this.name, "name cannot be null or empty");
Assert.notNull(this.getResponseConverter(), "responseConverter cannot be null");
Assert.notNull(this.inputType, "inputType cannot be null");

if (this.getInputTypeSchema() == null) {
Expand All @@ -143,86 +143,103 @@ public FunctionCallback build() {
var constructor = FunctionInvokingFunctionCallback.class.getDeclaredConstructor(String.class,
String.class, String.class, Type.class, Function.class, ObjectMapper.class, BiFunction.class);
constructor.setAccessible(true);
return constructor.newInstance(this.name, this.getDescriptionExt(), this.getInputTypeSchema(),
return constructor.newInstance(this.name, this.getToolDescription(), this.getInputTypeSchema(),
this.inputType, this.getResponseConverter(), this.getObjectMapper(), finalBiFunction);
}
catch (ReflectiveOperationException ex) {
throw new IllegalStateException("Failed to create FunctionInvokingFunctionCallback instance", ex);
}
}

private String getDescriptionExt() {
private String getToolDescription() {
if (StringUtils.hasText(this.getDescription())) {
return this.getDescription();
}
return generateDescription(this.name);
return ParsingUtils.reConcatenateCamelCase(this.name, " ");
}

}

/**
* Arconia {@link FunctionCallback.MethodInvokingSpec} implementation.
*/
final class ArconiaMethodInvokingSpec extends DefaultCommonCallbackInvokingSpec<FunctionCallback.MethodInvokingSpec>
public static final class ArconiaMethodInvokingSpec
extends DefaultCommonCallbackInvokingSpec<FunctionCallback.MethodInvokingSpec>
implements FunctionCallback.MethodInvokingSpec {

private final Method method;

private String name;

private Class<?> targetClass;

private Object targetObject;
@Nullable
private Object source;

private ArconiaMethodInvokingSpec(Method method) {
Assert.notNull(method, "method cannot be null");
Assert.hasText(method.getName(), "method name cannot be null or empty");

this.method = method;
this.name = getToolName(method.getAnnotation(Tool.class), method.getName());
this.description = getToolDescription(method.getAnnotation(Tool.class), method.getName());
this.schemaType = getToolSchemaType(method.getAnnotation(Tool.class));
}

@Override
public FunctionCallback.MethodInvokingSpec name(String name) {
public ArconiaMethodInvokingSpec name(String name) {
Assert.hasText(name, "name cannot be null or empty");
this.name = name;
return this;
}

@Override
public FunctionCallback.MethodInvokingSpec targetClass(Class<?> targetClass) {
Assert.notNull(targetClass, "targetClass cannot be null");
this.targetClass = targetClass;
public ArconiaMethodInvokingSpec source(Object source) {
Assert.notNull(source, "source cannot be null");
this.source = source;
return this;
}

@Override
public FunctionCallback.MethodInvokingSpec targetObject(Object methodObject) {
Assert.notNull(methodObject, "methodObject cannot be null");
this.targetObject = methodObject;
this.targetClass = methodObject.getClass();
public ArconiaMethodInvokingSpec targetClass(Class<?> targetClass) {
return this;
}

@Override
public FunctionCallback build() {
Assert.isTrue(this.targetClass != null || this.targetObject != null,
"targetClass or targetObject cannot be null");
public ArconiaMethodInvokingSpec targetObject(Object targetObject) {
return source(targetObject);
}

@Override
public FunctionCallback build() {
try {
var constructor = MethodInvokingFunctionCallback.class.getDeclaredConstructor(Object.class,
Method.class, String.class, ObjectMapper.class, String.class, Function.class);
constructor.setAccessible(true);
return constructor.newInstance(this.targetObject, method, this.getDescriptionExt(),
this.getObjectMapper(), this.name, this.getResponseConverter());
return constructor.newInstance(this.source, method, this.getDescription(), this.getObjectMapper(),
this.name, this.getResponseConverter());
}
catch (ReflectiveOperationException ex) {
throw new IllegalStateException("Failed to create MethodInvokingFunctionCallback instance", ex);
}
}

private String getDescriptionExt() {
if (StringUtils.hasText(this.getDescription())) {
return this.getDescription();
private static String getToolName(@Nullable Tool tool, String methodName) {
if (tool == null) {
return methodName;
}
return StringUtils.hasText(tool.name()) ? tool.name() : methodName;
}

private static String getToolDescription(@Nullable Tool tool, String methodName) {
if (tool == null) {
return ParsingUtils.reConcatenateCamelCase(methodName, " ");
}
return StringUtils.hasText(tool.value()) ? tool.value() : methodName;
}

private static FunctionCallback.SchemaType getToolSchemaType(@Nullable Tool tool) {
if (tool == null) {
return FunctionCallback.SchemaType.JSON_SCHEMA;
}
return generateDescription(StringUtils.hasText(this.name) ? this.name : this.method.getName());
return tool.schemaType();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
/**
* Wrapper for {@link FunctionCallback} to identify tools in Spring AI.
*/
public interface ToolCallback extends FunctionCallback {
public interface ToolCallback extends FunctionCallback, ToolMetadata {

/**
* Creates a new {@link FunctionCallback.Builder} instance.
Expand Down
Loading

0 comments on commit 628cab2

Please sign in to comment.