Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasVitale committed Oct 31, 2024
1 parent 9d08bef commit 1c22ee6
Show file tree
Hide file tree
Showing 32 changed files with 295 additions and 221 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ String chat(String question) {
}

@GetMapping("/chat/generic-options")
String chatWithGenericOptions(String question) {
String chatGenericOptions(String question) {
return chatClient
.prompt(question)
.options(ChatOptionsBuilder.builder()
Expand All @@ -41,7 +41,7 @@ String chatWithGenericOptions(String question) {
}

@GetMapping("/chat/provider-options")
String chatWithProviderOptions(String question) {
String chatProviderOptions(String question) {
return chatClient
.prompt(question)
.options(MistralAiChatOptions.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ String chat(String question) {
}

@GetMapping("/chat/generic-options")
String chatWithGenericOptions(String question) {
String chatGenericOptions(String question) {
return chatModel.call(new Prompt(question, ChatOptionsBuilder.builder()
.withModel(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getName())
.withTemperature(0.9)
Expand All @@ -38,7 +38,7 @@ String chatWithGenericOptions(String question) {
}

@GetMapping("/chat/provider-options")
String chatWithProviderOptions(String question) {
String chatProviderOptions(String question) {
return chatModel.call(new Prompt(question, MistralAiChatOptions.builder()
.withSafePrompt(true)
.build()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ String chatOpenAi(String question) {
}

@GetMapping("/chat/mistral-ai-options")
String chatWithMistralAiOptions(String question) {
String chatMistralAiOptions(String question) {
return mistralAichatClient
.prompt(question)
.options(MistralAiChatOptions.builder()
Expand All @@ -53,7 +53,7 @@ String chatWithMistralAiOptions(String question) {
}

@GetMapping("/chat/openai-options")
String chatWithOpenAiOptions(String question) {
String chatOpenAiOptions(String question) {
return openAichatClient
.prompt(question)
.options(OpenAiChatOptions.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ String chatOpenAi(String question) {
}

@GetMapping("/chat/mistral-ai-options")
String chatWithMistralAiOptions(String question) {
String chatMistralAiOptions(String question) {
return mistralAiChatModel.call(new Prompt(question, MistralAiChatOptions.builder()
.withModel(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue())
.withTemperature(1.0)
Expand All @@ -46,7 +46,7 @@ String chatWithMistralAiOptions(String question) {
}

@GetMapping("/chat/openai-options")
String chatWithOpenAiOptions(String question) {
String chatOpenAiOptions(String question) {
return openAiChatModel.call(new Prompt(question, OpenAiChatOptions.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
.withTemperature(1.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ String chat(String question) {
}

@GetMapping("/chat/generic-options")
String chatWithGenericOptions(String question) {
String chatGenericOptions(String question) {
return chatClient
.prompt(question)
.options(ChatOptionsBuilder.builder()
Expand All @@ -40,7 +40,7 @@ String chatWithGenericOptions(String question) {
}

@GetMapping("/chat/provider-options")
String chatWithProviderOptions(String question) {
String chatProviderOptions(String question) {
return chatClient
.prompt(question)
.options(OllamaOptions.builder()
Expand All @@ -51,7 +51,7 @@ String chatWithProviderOptions(String question) {
}

@GetMapping("/chat/huggingface")
String chatWithHuggingFace(String question) {
String chatHuggingFace(String question) {
return chatClient
.prompt(question)
.options(ChatOptionsBuilder.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ String chat(String question) {
}

@GetMapping("/chat/generic-options")
String chatWithGenericOptions(String question) {
String chatGenericOptions(String question) {
return chatModel.call(new Prompt(question, ChatOptionsBuilder.builder()
.withModel("llama3.2:1b")
.withTemperature(0.9)
Expand All @@ -37,15 +37,15 @@ String chatWithGenericOptions(String question) {
}

@GetMapping("/chat/provider-options")
String chatWithProviderOptions(String question) {
// Chat completion call configured with Ollama provider-specific options
// (here, the repeat penalty sampling parameter).
String chatProviderOptions(String question) {
    var providerOptions = OllamaOptions.builder()
            .withRepeatPenalty(1.5)
            .build();
    var prompt = new Prompt(question, providerOptions);
    var chatResponse = chatModel.call(prompt);
    return chatResponse.getResult().getOutput().getContent();
}

@GetMapping("/chat/huggingface")
String chatWithHuggingFace(String question) {
String chatHuggingFace(String question) {
return chatModel.call(new Prompt(question, ChatOptionsBuilder.builder()
.withModel("hf.co/SanctumAI/Llama-3.2-1B-Instruct-GGUF")
.build()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ String chat(String question) {
}

@GetMapping("/chat/generic-options")
String chatWithGenericOptions(String question) {
String chatGenericOptions(String question) {
return chatClient
.prompt(question)
.options(ChatOptionsBuilder.builder()
Expand All @@ -41,7 +41,7 @@ String chatWithGenericOptions(String question) {
}

@GetMapping("/chat/provider-options")
String chatWithProviderOptions(String question) {
String chatProviderOptions(String question) {
return chatClient
.prompt(question)
.options(OpenAiChatOptions.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ String chat(String question) {
}

@GetMapping("/chat/generic-options")
String chatWithGenericOptions(String question) {
String chatGenericOptions(String question) {
return chatModel.call(new Prompt(question, ChatOptionsBuilder.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
.withTemperature(0.9)
Expand All @@ -38,7 +38,7 @@ String chatWithGenericOptions(String question) {
}

@GetMapping("/chat/provider-options")
String chatWithProviderOptions(String question) {
String chatProviderOptions(String question) {
return chatModel.call(new Prompt(question, OpenAiChatOptions.builder()
.withLogprobs(true)
.build()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public ChatController(ChatClient.Builder chatClientBuilder) {
}

@GetMapping("/chat/image/file")
String chatFromImageFile(String question) {
String chatImageFile(String question) {
return chatClient.prompt()
.user(userSpec -> userSpec
.text(question)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ChatModelController {
}

@GetMapping("/chat/image/file")
String chatFromImageFile(String question) {
String chatImageFile(String question) {
var userMessage = new UserMessage(question, new Media(MimeTypeUtils.IMAGE_PNG, image));
var prompt = new Prompt(userMessage);
var chatResponse = chatModel.call(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public ChatController(ChatClient.Builder chatClientBuilder) {
}

@GetMapping("/chat/image/file")
String chatFromImageFile(String question) {
String chatImageFile(String question) {
return chatClient.prompt()
.user(userSpec -> userSpec
.text(question)
Expand All @@ -38,7 +38,7 @@ String chatFromImageFile(String question) {
}

@GetMapping("/chat/image/url")
String chatFromImageUrl(String question) throws MalformedURLException {
String chatImageUrl(String question) throws MalformedURLException {
var imageUrl = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png";
var url = URI.create(imageUrl).toURL();
return chatClient.prompt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class ChatModelController {
}

@GetMapping("/chat/image/file")
String chatFromImageFile(String question) {
// Multimodal chat: sends the user's question together with a PNG image
// (the `image` resource held by this controller) in a single prompt.
String chatImageFile(String question) {
    var imageMedia = new Media(MimeTypeUtils.IMAGE_PNG, image);
    var message = new UserMessage(question, imageMedia);
    return chatModel.call(new Prompt(message))
            .getResult().getOutput().getContent();
}

@GetMapping("/chat/image/url")
String chatFromImageUrl(String question) throws MalformedURLException {
String chatImageUrl(String question) throws MalformedURLException {
var imageUrl = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png";
var url = URI.create(imageUrl).toURL();

Expand Down
55 changes: 30 additions & 25 deletions 06-embedding-models/embedding-models-mistral-ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,61 @@ Spring AI provides an `EmbeddingModel` abstraction for integrating with LLMs via
When using the _Spring AI Mistral AI Spring Boot Starter_, an `EmbeddingModel` object is autoconfigured for you to use Mistral AI.

```java
@RestController
class EmbeddingController {
private final EmbeddingModel embeddingModel;

EmbeddingController(EmbeddingModel embeddingModel) {
this.embeddingModel = embeddingModel;
}

@GetMapping("/embed")
String embed(@RequestParam(defaultValue = "And Gandalf yelled: 'You shall not pass!'") String message) {
var embeddings = embeddingModel.embed(message);
return "Size of the embedding vector: " + embeddings.size();
}
@Bean
CommandLineRunner embed(EmbeddingModel embeddingModel) {
return _ -> {
var embeddings = embeddingModel.embed("And Gandalf yelled: 'You shall not pass!'");
System.out.println("Size of the embedding vector: " + embeddings.length);
};
}
```

## Running the application
## Mistral AI

The application relies on the Mistral AI API for providing LLMs.

First, make sure you have a [Mistral AI account](https://console.mistral.ai).
Then, define an environment variable with the Mistral AI API Key associated to your Mistral AI account as the value.
### Create a Mistral AI account

Visit [console.mistral.ai](https://console.mistral.ai) and sign up for a new account.
You can choose the "Experiment" plan, which gives you access to the Mistral APIs for free.

### Configure API Key

In the Mistral AI console, navigate to _API Keys_ and generate a new API key.
Copy and securely store your API key on your machine as an environment variable.
The application will use it to access the Mistral AI API.

```shell
export SPRING_AI_MISTRALAI_API_KEY=<INSERT KEY HERE>
export MISTRALAI_API_KEY=<YOUR-API-KEY>
```

Finally, run the Spring Boot application.
## Running the application

Run the application.

```shell
./gradlew bootRun
```

## Calling the application

You can now call the application that will use Mistral AI and _mistral-embed_ to generate a vector representation (embeddings) of a default text.
This example uses [httpie](https://httpie.io) to send HTTP requests.
> [!NOTE]
> These examples use the [httpie](https://httpie.io) CLI to send HTTP requests.

Call the application that will use an embedding model to generate embeddings for your query.

```shell
http :8080/embed
http :8080/embed query=="The capital of Italy is Rome"
```

Try passing your custom prompt and check the result.
The next request is configured with generic portable options.

```shell
http :8080/embed message=="The capital of Italy is Rome"
http :8080/embed/generic-options query=="The capital of Italy is Rome" -b
```

The next request is configured with Mistral AI-specific customizations.
The next request is configured with the provider's specific options.

```shell
http :8080/embed/mistral-ai-options message=="The capital of Italy is Rome"
http :8080/embed/provider-options query=="The capital of Italy is Rome" -b
```
5 changes: 2 additions & 3 deletions 06-embedding-models/embedding-models-mistral-ai/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.ai:spring-ai-mistral-ai-spring-boot-starter'

testAndDevelopmentOnly 'org.springframework.boot:spring-boot-devtools'
developmentOnly 'org.springframework.boot:spring-boot-devtools'

testImplementation 'org.springframework.boot:spring-boot-starter-test'
testImplementation 'org.springframework.boot:spring-boot-testcontainers'
testImplementation 'org.testcontainers:junit-jupiter'
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
}

tasks.named('test') {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package com.thomasvitale.ai.spring;

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.mistralai.MistralAiEmbeddingOptions;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;
Expand All @@ -19,15 +20,24 @@ class EmbeddingController {
}

@GetMapping("/embed")
String embed(@RequestParam(defaultValue = "And Gandalf yelled: 'You shall not pass!'") String message) {
var embeddings = embeddingModel.embed(message);
// Computes the embedding for the given query and reports the vector's dimension.
String embed(String query) {
    var vector = embeddingModel.embed(query);
    return "Size of the embedding vector: " + vector.length;
}

@GetMapping("/embed/mistral-ai-options")
String embedWithMistralAiOptions(@RequestParam(defaultValue = "And Gandalf yelled: 'You shall not pass!'") String message) {
var embeddings = embeddingModel.call(new EmbeddingRequest(List.of(message), MistralAiEmbeddingOptions.builder()
.withModel("mistral-embed")
// Embedding request configured through the portable, provider-agnostic
// options builder (model selected via the Mistral AI API enum).
@GetMapping("/embed/generic-options")
String embedGenericOptions(String query) {
    var genericOptions = EmbeddingOptionsBuilder.builder()
            .withModel(MistralAiApi.EmbeddingModel.EMBED.getValue())
            .build();
    var request = new EmbeddingRequest(List.of(query), genericOptions);
    var vector = embeddingModel.call(request).getResult().getOutput();
    return "Size of the embedding vector: " + vector.length;
}

@GetMapping("/embed/provider-options")
String embedProviderOptions(String query) {
var embeddings = embeddingModel.call(new EmbeddingRequest(List.of(query), MistralAiEmbeddingOptions.builder()
.withEncodingFormat("float")
.build()))
.getResult().getOutput();
return "Size of the embedding vector: " + embeddings.length;
Expand Down
Loading

0 comments on commit 1c22ee6

Please sign in to comment.