-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aef4e94
commit 3670877
Showing
3 changed files
with
64 additions
and
0 deletions.
There are no files selected for viewing
2 changes: 2 additions & 0 deletions
2
...ool-calling/tool-calling-openai/src/main/java/com/thomasvitale/ai/spring/BookService.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
...ing-openai/src/main/java/com/thomasvitale/ai/spring/ToolBeanRegistrationAotProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package com.thomasvitale.ai.spring; | ||
|
||
import org.springframework.ai.tool.annotation.Tool; | ||
import org.springframework.aot.generate.GenerationContext; | ||
import org.springframework.aot.hint.MemberCategory; | ||
import org.springframework.aot.hint.ReflectionHints; | ||
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; | ||
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; | ||
import org.springframework.beans.factory.aot.BeanRegistrationCode; | ||
import org.springframework.beans.factory.support.RegisteredBean; | ||
import org.springframework.core.annotation.MergedAnnotations; | ||
import org.springframework.lang.Nullable; | ||
import org.springframework.util.ReflectionUtils; | ||
|
||
import java.util.stream.Stream; | ||
|
||
import static org.springframework.core.annotation.MergedAnnotations.SearchStrategy.TYPE_HIERARCHY; | ||
|
||
/** | ||
* AOT {@code BeanRegistrationAotProcessor} that detects the presence of the | ||
* {@link Tool} annotation on methods and creates the required reflection hints. | ||
*/ | ||
class ToolBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { | ||
|
||
@Override | ||
@Nullable | ||
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { | ||
Class<?> beanClass = registeredBean.getBeanClass(); | ||
MergedAnnotations.Search search = MergedAnnotations.search(TYPE_HIERARCHY); | ||
|
||
boolean hasAnyToolAnnotatedMethods = Stream.of(ReflectionUtils.getDeclaredMethods(beanClass)) | ||
.anyMatch(method -> search.from(method).isPresent(Tool.class)); | ||
|
||
if (hasAnyToolAnnotatedMethods) { | ||
return new AotContribution(beanClass); | ||
} | ||
|
||
return null; | ||
} | ||
|
||
private static class AotContribution implements BeanRegistrationAotContribution { | ||
|
||
private final MemberCategory[] memberCategories = new MemberCategory[] { MemberCategory.INVOKE_DECLARED_METHODS, | ||
MemberCategory.INVOKE_PUBLIC_METHODS }; | ||
|
||
private final Class<?> toolClass; | ||
|
||
public AotContribution(Class<?> toolClass) { | ||
this.toolClass = toolClass; | ||
} | ||
|
||
@Override | ||
public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { | ||
ReflectionHints reflectionHints = generationContext.getRuntimeHints().reflection(); | ||
reflectionHints.registerType(toolClass, memberCategories); | ||
} | ||
|
||
} | ||
|
||
} |
2 changes: 2 additions & 0 deletions
2
patterns/tool-calling/tool-calling-openai/src/main/resources/META-INF/spring/aot.factories
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\ | ||
com.thomasvitale.ai.spring.ToolBeanRegistrationAotProcessor |