[OPIK-611] support gemini models in playground #987

Merged · 20 commits · Jan 13, 2025
Changes from all commits
apps/opik-backend/pom.xml (4 additions, 0 deletions)
@@ -221,6 +221,10 @@
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-anthropic</artifactId>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-google-ai-gemini</artifactId>
        </dependency>

<!-- Test -->

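For context, a minimal sketch of how the chat model shipped by langchain4j-google-ai-gemini is typically constructed. This snippet is not part of the PR; the environment-variable key lookup and the model name are placeholder assumptions, and it presumes the langchain4j 0.x builder API.

import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;

public class GeminiDependencySketch {
    public static void main(String[] args) {
        // Placeholder: read the API key from the environment for this sketch.
        GoogleAiGeminiChatModel model = GoogleAiGeminiChatModel.builder()
                .apiKey(System.getenv("GEMINI_API_KEY"))
                .modelName("gemini-1.5-flash") // one of the names in GeminiModelName below
                .build();
        // ChatLanguageModel exposes a simple blocking call in langchain4j 0.x.
        System.out.println(model.generate("Reply with a one-line greeting."));
    }
}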
Application class (package com.comet.opik)
@@ -1,6 +1,7 @@
package com.comet.opik;

import com.comet.opik.api.error.JsonInvalidFormatExceptionMapper;
import com.comet.opik.domain.llmproviders.LlmProviderClientModule;
import com.comet.opik.infrastructure.ConfigurationModule;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.OpikConfiguration;
@@ -72,7 +73,7 @@ public void initialize(Bootstrap<OpikConfiguration> bootstrap) {
.withPlugins(new SqlObjectPlugin(), new Jackson2Plugin()))
.modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(),
new RateLimitModule(), new NameGeneratorModule(), new HttpModule(), new EventModule(),
new ConfigurationModule(), new BiModule())
new ConfigurationModule(), new BiModule(), new LlmProviderClientModule())
.installers(JobGuiceyInstaller.class)
.listen(new OpikGuiceyLifecycleEventListener())
.enableAutoConfig()
LlmProvider enum
@@ -11,7 +11,9 @@
@RequiredArgsConstructor
public enum LlmProvider {
OPEN_AI("openai"),
ANTHROPIC("anthropic");
ANTHROPIC("anthropic"),
GEMINI("gemini"),
;

@JsonValue
private final String value;
New file: ChunkedResponseHandler.java (com.comet.opik.domain.llmproviders)
@@ -0,0 +1,60 @@
package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import lombok.NonNull;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

public record ChunkedResponseHandler(
        @NonNull Consumer<ChatCompletionResponse> handleMessage,
        @NonNull Runnable handleClose,
        @NonNull Consumer<Throwable> handleError,
        @NonNull String model) implements StreamingResponseHandler<AiMessage> {

    @Override
    public void onNext(@NonNull String content) {
        handleMessage.accept(ChatCompletionResponse.builder()
                .model(model)
                .choices(List.of(ChatCompletionChoice.builder()
                        .delta(Delta.builder()
                                .content(content)
                                .role(Role.ASSISTANT)
                                .build())
                        .build()))
                .build());
    }

    @Override
    public void onComplete(@NonNull Response<AiMessage> response) {
        handleMessage.accept(ChatCompletionResponse.builder()
                .model(model)
                .choices(List.of(ChatCompletionChoice.builder()
                        .delta(Delta.builder()
                                .content("")
                                .role(Role.ASSISTANT)
                                .build())
                        .build()))
                .usage(Usage.builder()
                        .promptTokens(response.tokenUsage().inputTokenCount())
                        .completionTokens(response.tokenUsage().outputTokenCount())
                        .totalTokens(response.tokenUsage().totalTokenCount())
                        .build())
                .id(Optional.ofNullable(response.metadata().get("id")).map(Object::toString).orElse(null))
                .build());
        handleClose.run();
    }

    @Override
    public void onError(@NonNull Throwable throwable) {
        handleError.accept(throwable);
    }
}
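Because the record implements StreamingResponseHandler<AiMessage>, it plugs directly into any langchain4j 0.x streaming model. A hedged usage sketch, not part of the PR; the model name, key lookup, and println sinks are placeholder assumptions:

import com.comet.opik.domain.llmproviders.ChunkedResponseHandler;
import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;

public class StreamingHandlerSketch {
    public static void main(String[] args) {
        var model = GoogleAiGeminiStreamingChatModel.builder()
                .apiKey(System.getenv("GEMINI_API_KEY")) // placeholder key source
                .modelName("gemini-1.5-flash")
                .build();

        var handler = new ChunkedResponseHandler(
                chunk -> System.out.println(chunk),          // one OpenAI-shaped chunk per token batch
                () -> System.out.println("[stream closed]"), // runs once, after the final chunk
                Throwable::printStackTrace,                  // surfaces provider errors
                "gemini-1.5-flash");                         // echoed into every chunk's model field

        // Tokens arrive via onNext(...); the final Response<AiMessage> with token
        // usage arrives via onComplete(...), which emits the closing usage chunk.
        model.generate("Stream a short greeting.", handler);
    }
}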
New file: GeminiModelName.java (com.comet.opik.domain.llmproviders)
@@ -0,0 +1,25 @@
package com.comet.opik.domain.llmproviders;

import lombok.RequiredArgsConstructor;

/*
Langchain4j doesn't provide an enum of Gemini model names.
This list is taken from: https://ai.google.dev/gemini-api/docs/models/gemini
*/
@RequiredArgsConstructor
public enum GeminiModelName {
[Inline review comment · Collaborator]
Make sure these values are as complete as possible and also aligned with the FrontEnd choices.

[Reply · Contributor, Author]
I double-checked against the Gemini documentation (link in the comment above). Will update the FE ticket accordingly.
    GEMINI_2_0_FLASH("gemini-2.0-flash-exp"),
    GEMINI_1_5_FLASH("gemini-1.5-flash"),
    GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b"),
    GEMINI_1_5_PRO("gemini-1.5-pro"),
    GEMINI_1_0_PRO("gemini-1.0-pro"),
    TEXT_EMBEDDING("text-embedding-004"),
    AQA("aqa");

    private final String value;

    @Override
    public String toString() {
        return value;
    }
}
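The enum overrides toString() so each constant yields the wire-format identifier the Gemini API expects; constants can therefore be passed wherever a raw model string is required. A small illustrative check (not part of the PR):

import com.comet.opik.domain.llmproviders.GeminiModelName;

public class GeminiModelNameSketch {
    public static void main(String[] args) {
        // toString() returns the raw model identifier, not the enum constant name.
        System.out.println(GeminiModelName.GEMINI_1_5_FLASH); // gemini-1.5-flash
        System.out.println(GeminiModelName.GEMINI_2_0_FLASH); // gemini-2.0-flash-exp
    }
}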
LlmProviderAnthropic.java (com.comet.opik.domain.llmproviders)
@@ -1,68 +1,33 @@
package com.comet.opik.domain.llmproviders;

import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.chat.AssistantMessage;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.Role;
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.UserMessage;
import dev.ai4j.openai4j.shared.Usage;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.AnthropicContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage;
import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicRole;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException;
import dev.langchain4j.model.output.Response;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

import static com.comet.opik.domain.ChatCompletionService.ERROR_EMPTY_MESSAGES;
import static com.comet.opik.domain.ChatCompletionService.ERROR_NO_COMPLETION_TOKENS;

@RequiredArgsConstructor
@Slf4j
class LlmProviderAnthropic implements LlmProviderService {
private final @NonNull LlmProviderClientConfig llmProviderClientConfig;
private final @NonNull AnthropicClient anthropicClient;

public LlmProviderAnthropic(@NonNull LlmProviderClientConfig llmProviderClientConfig, @NonNull String apiKey) {
this.llmProviderClientConfig = llmProviderClientConfig;
this.anthropicClient = newClient(apiKey);
}

@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
var response = anthropicClient.createMessage(toAnthropicCreateMessageRequest(request));
var mapper = LlmProviderAnthropicMapper.INSTANCE;
var response = anthropicClient.createMessage(mapper.toCreateMessageRequest(request));

return ChatCompletionResponse.builder()
.id(response.id)
.model(response.model)
.choices(response.content.stream().map(content -> toChatCompletionChoice(response, content))
.toList())
.usage(Usage.builder()
.promptTokens(response.usage.inputTokens)
.completionTokens(response.usage.outputTokens)
.totalTokens(response.usage.inputTokens + response.usage.outputTokens)
.build())
.build();
return mapper.toResponse(response);
}

@Override
@@ -72,7 +37,7 @@ public void generateStream(
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose, @NonNull Consumer<Throwable> handleError) {
validateRequest(request);
anthropicClient.createMessage(toAnthropicCreateMessageRequest(request),
anthropicClient.createMessage(LlmProviderAnthropicMapper.INSTANCE.toCreateMessageRequest(request),
new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
}

@@ -88,151 +53,12 @@ public void validateRequest(@NonNull ChatCompletionRequest request) {
}

@Override
public @NonNull Optional<ErrorMessage> getLlmProviderError(Throwable runtimeException) {
public Optional<ErrorMessage> getLlmProviderError(@NonNull Throwable runtimeException) {
if (runtimeException instanceof AnthropicHttpException anthropicHttpException) {
return Optional.of(new ErrorMessage(anthropicHttpException.statusCode(),
anthropicHttpException.getMessage()));
}

return Optional.empty();
}

private AnthropicCreateMessageRequest toAnthropicCreateMessageRequest(ChatCompletionRequest request) {
var builder = AnthropicCreateMessageRequest.builder();
Optional.ofNullable(request.toolChoice())
.ifPresent(toolChoice -> builder.toolChoice(AnthropicToolChoice.from(
request.toolChoice().toString())));
return builder
.stream(request.stream())
.model(request.model())
.messages(request.messages().stream()
.filter(message -> List.of(Role.ASSISTANT, Role.USER).contains(message.role()))
.map(this::toMessage).toList())
.system(request.messages().stream()
.filter(message -> message.role() == Role.SYSTEM)
.map(this::toSystemMessage).toList())
.temperature(request.temperature())
.topP(request.topP())
.stopSequences(request.stop())
.maxTokens(request.maxCompletionTokens())
.build();
}

private AnthropicMessage toMessage(Message message) {
if (message.role() == Role.ASSISTANT) {
return AnthropicMessage.builder()
.role(AnthropicRole.ASSISTANT)
.content(List.of(new AnthropicTextContent(((AssistantMessage) message).content())))
.build();
}

if (message.role() == Role.USER) {
return AnthropicMessage.builder()
.role(AnthropicRole.USER)
.content(List.of(toAnthropicMessageContent(((UserMessage) message).content())))
.build();
}

throw new BadRequestException("unexpected message role: " + message.role());
}

private AnthropicTextContent toSystemMessage(Message message) {
if (message.role() != Role.SYSTEM) {
throw new BadRequestException("expecting only system role, got: " + message.role());
}

return new AnthropicTextContent(((SystemMessage) message).content());
}

private AnthropicMessageContent toAnthropicMessageContent(Object rawContent) {
if (rawContent instanceof String content) {
return new AnthropicTextContent(content);
}

throw new BadRequestException("only text content is supported");
}

private ChatCompletionChoice toChatCompletionChoice(
AnthropicCreateMessageResponse response, AnthropicContent content) {
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder()
.name(content.name)
.content(content.text)
.build())
.finishReason(response.stopReason)
.build();
}

private AnthropicClient newClient(String apiKey) {
var anthropicClientBuilder = AnthropicClient.builder();
Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::url)
.ifPresent(url -> {
if (StringUtils.isNotEmpty(url)) {
anthropicClientBuilder.baseUrl(url);
}
});
Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::version)
.ifPresent(version -> {
if (StringUtils.isNotBlank(version)) {
anthropicClientBuilder.version(version);
}
});
Optional.ofNullable(llmProviderClientConfig.getLogRequests())
.ifPresent(anthropicClientBuilder::logRequests);
Optional.ofNullable(llmProviderClientConfig.getLogResponses())
.ifPresent(anthropicClientBuilder::logResponses);
// anthropic client builder only receives one timeout variant
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
return anthropicClientBuilder
.apiKey(apiKey)
.build();
}

private record ChunkedResponseHandler(
Consumer<ChatCompletionResponse> handleMessage,
Runnable handleClose,
Consumer<Throwable> handleError,
String model) implements StreamingResponseHandler<AiMessage> {

@Override
public void onNext(String s) {
handleMessage.accept(ChatCompletionResponse.builder()
.model(model)
.choices(List.of(ChatCompletionChoice.builder()
.delta(Delta.builder()
.content(s)
.role(Role.ASSISTANT)
.build())
.build()))
.build());
}

@Override
public void onComplete(Response<AiMessage> response) {
handleMessage.accept(ChatCompletionResponse.builder()
.model(model)
.choices(List.of(ChatCompletionChoice.builder()
.delta(Delta.builder()
.content("")
.role(Role.ASSISTANT)
.build())
.build()))
.usage(Usage.builder()
.promptTokens(response.tokenUsage().inputTokenCount())
.completionTokens(response.tokenUsage().outputTokenCount())
.totalTokens(response.tokenUsage().totalTokenCount())
.build())
.id((String) response.metadata().get("id"))
.build());
handleClose.run();
}

@Override
public void onError(Throwable throwable) {
handleError.accept(throwable);
}
}
}
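The LlmProviderAnthropicMapper referenced above is not shown in this hunk. Judging by the INSTANCE accessor it is presumably a MapStruct-style mapper; the following is a hypothetical sketch of the shape the call sites imply (the interface name is suffixed with "Sketch" to flag it as an assumption, and the real mapper would carry the message/usage conversion logic deleted above, likely as default methods):

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import org.mapstruct.Mapper;
import org.mapstruct.factory.Mappers;

@Mapper
public interface LlmProviderAnthropicMapperSketch {
    LlmProviderAnthropicMapperSketch INSTANCE = Mappers.getMapper(LlmProviderAnthropicMapperSketch.class);

    // OpenAI-shaped request -> Anthropic request (messages, system, temperature, ...)
    AnthropicCreateMessageRequest toCreateMessageRequest(ChatCompletionRequest request);

    // Anthropic response -> OpenAI-shaped response (choices, usage, id)
    ChatCompletionResponse toResponse(AnthropicCreateMessageResponse response);
}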