Skip to content

Commit

Permalink
First version
Browse files Browse the repository at this point in the history
  • Loading branch information
andreadimaio committed Jun 6, 2024
1 parent 9102884 commit 2e3d668
Show file tree
Hide file tree
Showing 30 changed files with 3,403 additions and 111 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.deployment;

import io.quarkus.builder.item.SimpleBuildItem;

public final class AiCacheBuildItem extends SimpleBuildItem {

private boolean enable;

public AiCacheBuildItem(boolean enable) {
this.enable = enable;
}

public boolean isEnable() {
return enable;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package io.quarkiverse.langchain4j.deployment;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.IndexView;

import io.quarkiverse.langchain4j.runtime.AiCacheRecorder;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheConfig;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;

public class AiCacheProcessor {

@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void setupBeans(ChatMemoryBuildConfig buildConfig, AiCacheConfig cacheConfig,
AiCacheRecorder recorder,
CombinedIndexBuildItem indexBuildItem,
BuildProducer<AiCacheBuildItem> aiCacheBuildItemProducer,
BuildProducer<UnremovableBeanBuildItem> unremovableProducer,
BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer) {

IndexView index = indexBuildItem.getIndex();
boolean enableCache = false;

for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
continue;
}

ClassInfo declarativeAiServiceClassInfo = instance.target().asClass();

if (declarativeAiServiceClassInfo.hasAnnotation(LangChain4jDotNames.CACHE_RESULT)) {
enableCache = true;
break;
}
}

aiCacheBuildItemProducer.produce(new AiCacheBuildItem(enableCache));

if (enableCache) {
var configurator = SyntheticBeanBuildItem
.configure(AiCacheProvider.class)
.setRuntimeInit()
.addInjectionPoint(ClassType.create(AiCacheStore.class))
.scope(ApplicationScoped.class)
.createWith(recorder.messageWindow(cacheConfig))
.defaultBean();

syntheticBeanProducer.produce(configurator.done());
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(AiCacheStore.class));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import org.objectweb.asm.tree.analysis.AnalyzerException;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.Moderate;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.ToolBox;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
Expand Down Expand Up @@ -185,6 +184,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,

Set<String> chatModelNames = new HashSet<>();
Set<String> moderationModelNames = new HashSet<>();

for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
continue; // should never happen
Expand All @@ -206,14 +206,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}

String chatModelName = NamedConfigUtil.DEFAULT_NAME;
String moderationModelName = NamedConfigUtil.DEFAULT_NAME;
String embeddingModelName = getModelName(instance.value("modelName"));

if (chatLanguageModelSupplierClassDotName == null) {
AnnotationValue modelNameValue = instance.value("modelName");
if (modelNameValue != null) {
String modelNameValueStr = modelNameValue.asString();
if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) {
chatModelName = modelNameValueStr;
}
}
chatModelName = getModelName(modelNameValue);
chatModelNames.add(chatModelName);
}

Expand All @@ -239,6 +237,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

// the default value depends on whether tools exists or not - if they do, then we require a AiCacheProvider bean
DotName aiCacheProviderSupplierClassDotName = LangChain4jDotNames.BEAN_AI_CACHE_PROVIDER_SUPPLIER;
AnnotationValue aiCacheProviderSupplierValue = instance.value("cacheProviderSupplier");
if (aiCacheProviderSupplierValue != null) {
aiCacheProviderSupplierClassDotName = aiCacheProviderSupplierValue.asClass().name();
if (!aiCacheProviderSupplierClassDotName
.equals(LangChain4jDotNames.BEAN_AI_CACHE_PROVIDER_SUPPLIER)) {
validateSupplierAndRegisterForReflection(aiCacheProviderSupplierClassDotName, index,
reflectiveClassProducer);
}
}

DotName retrieverClassDotName = null;
AnnotationValue retrieverValue = instance.value("retriever");
if (retrieverValue != null) {
Expand Down Expand Up @@ -292,17 +302,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}

// determine whether the method is annotated with @Moderate
String moderationModelName = NamedConfigUtil.DEFAULT_NAME;
for (MethodInfo method : declarativeAiServiceClassInfo.methods()) {
if (method.hasAnnotation(LangChain4jDotNames.MODERATE)) {
if (moderationModelSupplierClassName.equals(LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER)) {
AnnotationValue modelNameValue = instance.value("modelName");
if (modelNameValue != null) {
String modelNameValueStr = modelNameValue.asString();
if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) {
moderationModelName = modelNameValueStr;
}
}
moderationModelName = getModelName(modelNameValue);
moderationModelNames.add(moderationModelName);
}
break;
Expand All @@ -321,13 +325,16 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
aiCacheProviderSupplierClassDotName,
retrieverClassDotName,
retrievalAugmentorSupplierClassName,
customRetrievalAugmentorSupplierClassIsABean,
auditServiceSupplierClassName,
moderationModelSupplierClassName,
cdiScope,
chatModelName, moderationModelName));
chatModelName,
moderationModelName,
embeddingModelName));
}

for (String chatModelName : chatModelNames) {
Expand Down Expand Up @@ -361,7 +368,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
List<SelectedChatModelProviderBuildItem> selectedChatModelProvider,
BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer,
BuildProducer<UnremovableBeanBuildItem> unremoveableProducer) {
BuildProducer<UnremovableBeanBuildItem> unremoveableProducer,
AiCacheBuildItem aiCacheBuildItem) {

boolean needsChatModelBean = false;
boolean needsStreamingChatModelBean = false;
Expand All @@ -370,6 +378,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsRetrievalAugmentorBean = false;
boolean needsAuditServiceBean = false;
boolean needsModerationModelBean = false;
boolean needsAiCacheProvider = false;

Set<DotName> allToolNames = new HashSet<>();

for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
Expand All @@ -386,6 +396,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getChatMemoryProviderSupplierClassDotName().toString()
: null;

String aiCacheProviderSupplierClassName = bi.getAiCacheProviderSupplierClassDotName() != null
? bi.getAiCacheProviderSupplierClassDotName().toString()
: null;

String retrieverClassName = bi.getRetrieverClassDotName() != null
? bi.getRetrieverClassDotName().toString()
: null;
Expand All @@ -403,7 +417,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
: null);

// determine whether the method returns Multi<String>
boolean injectStreamingChatModelBean = false;
boolean needsStreamingChatModel = false;
for (MethodInfo method : declarativeAiServiceClassInfo.methods()) {
if (!LangChain4jDotNames.MULTI.equals(method.returnType().name())) {
continue;
Expand All @@ -419,29 +433,36 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
throw illegalConfiguration("Only Multi<String> is supported as a Multi return type. Offending method is '"
+ method.declaringClass().name().toString() + "#" + method.name() + "'");
}
injectStreamingChatModelBean = true;
needsStreamingChatModel = true;
}

boolean injectModerationModelBean = false;
boolean needsModerationModel = false;
for (MethodInfo method : declarativeAiServiceClassInfo.methods()) {
if (method.hasAnnotation(Moderate.class)) {
injectModerationModelBean = true;
if (method.hasAnnotation(LangChain4jDotNames.MODERATE)) {
needsModerationModel = true;
break;
}
}

String chatModelName = bi.getChatModelName();
String moderationModelName = bi.getModerationModelName();
String embeddingModelName = bi.getEmbeddingModelName();
boolean enableCache = aiCacheBuildItem.isEnable();

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(QuarkusAiServiceContext.class)
.forceApplicationClass()
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName,
toolClassNames, chatMemoryProviderSupplierClassName, aiCacheProviderSupplierClassName,
retrieverClassName,
retrievalAugmentorSupplierClassName,
auditServiceClassSupplierName, moderationModelSupplierClassName, chatModelName,
moderationModelName,
injectStreamingChatModelBean, injectModerationModelBean)))
embeddingModelName,
needsStreamingChatModel,
needsModerationModel,
enableCache)))
.setRuntimeInit()
.addQualifier()
.annotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName)
Expand All @@ -451,15 +472,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if ((chatLanguageModelSupplierClassName == null) && !selectedChatModelProvider.isEmpty()) {
if (NamedConfigUtil.isDefault(chatModelName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL));
if (injectStreamingChatModelBean) {
if (needsStreamingChatModel) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.STREAMING_CHAT_MODEL));
needsStreamingChatModelBean = true;
}
} else {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL),
AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build());

if (injectStreamingChatModelBean) {
if (needsStreamingChatModel) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.STREAMING_CHAT_MODEL),
AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build());
needsStreamingChatModelBean = true;
Expand Down Expand Up @@ -515,7 +536,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
}

if (LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER.toString()
.equals(moderationModelSupplierClassName) && injectModerationModelBean) {
.equals(moderationModelSupplierClassName) && needsModerationModel) {

if (NamedConfigUtil.isDefault(moderationModelName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.MODERATION_MODEL));
Expand All @@ -527,6 +548,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsModerationModelBean = true;
}

if (enableCache) {
if (LangChain4jDotNames.BEAN_AI_CACHE_PROVIDER_SUPPLIER.toString().equals(aiCacheProviderSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.AI_CACHE_PROVIDER));
}
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.AI_CACHE_PROVIDER));
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL));
needsAiCacheProvider = true;
}

syntheticBeanProducer.produce(configurator.done());
}

Expand All @@ -551,6 +581,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsModerationModelBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.MODERATION_MODEL));
}
if (needsAiCacheProvider) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.AI_CACHE_PROVIDER));
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.EMBEDDING_MODEL));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
}
Expand Down Expand Up @@ -870,6 +904,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
}

boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE);
boolean requiresCache = method.declaringClass().hasDeclaredAnnotation(LangChain4jDotNames.CACHE_RESULT)
|| method.hasDeclaredAnnotation(LangChain4jDotNames.CACHE_RESULT);

List<MethodParameterInfo> params = method.parameters();

Expand All @@ -887,7 +923,7 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
List<String> methodToolClassNames = gatherMethodToolClassNames(method);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
userMessageInfo, memoryIdParamPosition, requiresModeration, requiresCache,
returnType, metricsTimedInfo, metricsCountedInfo, spanInfo, methodToolClassNames);
}

Expand Down Expand Up @@ -1222,6 +1258,16 @@ static Map<String, Integer> toNameToArgsPositionMap(List<TemplateParameterInfo>
}
}

private String getModelName(AnnotationValue value) {
if (value != null) {
String modelNameValueStr = value.asString();
if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) {
return modelNameValueStr;
}
}
return NamedConfigUtil.DEFAULT_NAME;
}

public static final class AiServicesMethodBuildItem extends MultiBuildItem {

private final MethodInfo methodInfo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {
}

@BuildStep
public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
public void handleProviders(
AiCacheBuildItem aiCacheBuildItem,
BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
List<ChatModelProviderCandidateBuildItem> chatCandidateItems,
List<EmbeddingModelProviderCandidateBuildItem> embeddingCandidateItems,
List<ModerationModelProviderCandidateBuildItem> moderationCandidateItems,
Expand Down Expand Up @@ -165,7 +167,8 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
}
}
// If the Easy RAG extension requested to automatically generate an embedding model...
if (requestEmbeddingModels.isEmpty() && autoCreateEmbeddingModelBuildItem.isPresent()) {
if (requestEmbeddingModels.isEmpty()
&& (aiCacheBuildItem.isEnable() || autoCreateEmbeddingModelBuildItem.isPresent())) {
String provider = selectEmbeddingModelProvider(inProcessEmbeddingBuildItems, embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
Optional.empty(), "EmbeddingModel", "embedding-model");
Expand Down
Loading

0 comments on commit 2e3d668

Please sign in to comment.