Skip to content

Commit

Permalink
Merge pull request #10 from pepperoni21/images-support
Browse files Browse the repository at this point in the history
Added images support for multimodal models
  • Loading branch information
pepperoni21 authored Dec 14, 2023
2 parents 5c2b5e3 + 9f98e55 commit d7131fb
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ rustls = ["reqwest/rustls-tls"]
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
ollama-rs = { path = ".", features = ["stream"] }
base64 = "0.21.5"
1 change: 1 addition & 0 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod chat;
pub mod completion;
pub mod embeddings;
pub mod format;
pub mod images;
pub mod options;
23 changes: 22 additions & 1 deletion src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub mod request;

use request::ChatMessageRequest;

use super::images::Image;

#[cfg(feature = "stream")]
/// A stream of `ChatMessageResponse` objects
pub type ChatMessageResponseStream =
Expand Down Expand Up @@ -123,11 +125,16 @@ pub struct ChatMessageFinalResponseData {
pub struct ChatMessage {
/// Role attached to this message.
pub role: MessageRole,
/// Text content of the message.
pub content: String,
/// Optional images attached to the message; `ChatMessage::new` leaves this as `None`.
pub images: Option<Vec<Image>>,
}

impl ChatMessage {
/// Creates a message with the given `role` and `content` and no attached images.
///
/// Images can be attached afterwards with `with_images` or `add_image`.
pub fn new(role: MessageRole, content: String) -> Self {
    Self {
        role,
        content,
        images: None,
    }
}

pub fn user(content: String) -> Self {
Expand All @@ -141,6 +148,20 @@ impl ChatMessage {
/// Creates a message with the `MessageRole::System` role and the given `content`.
pub fn system(content: String) -> Self {
Self::new(MessageRole::System, content)
}

/// Sets the images attached to this message, replacing any that were set before.
///
/// Consumes the message and returns it so calls can be chained.
pub fn with_images(self, images: Vec<Image>) -> Self {
    let mut message = self;
    message.images = Some(images);
    message
}

/// Appends `image` to the images attached to this message.
///
/// If no image list is set yet, one is created first. Consumes the message
/// and returns it so calls can be chained.
pub fn add_image(mut self, image: Image) -> Self {
    self.images.get_or_insert_with(Vec::new).push(image);
    self
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
16 changes: 15 additions & 1 deletion src/generation/completion/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::Serialize;

use crate::generation::{format::FormatType, options::GenerationOptions};
use crate::generation::{format::FormatType, images::Image, options::GenerationOptions};

use super::GenerationContext;

Expand All @@ -10,6 +10,7 @@ pub struct GenerationRequest {
#[serde(rename = "model")]
pub model_name: String,
pub prompt: String,
pub images: Vec<Image>,
pub options: Option<GenerationOptions>,
pub system: Option<String>,
pub template: Option<String>,
Expand All @@ -23,6 +24,7 @@ impl GenerationRequest {
Self {
model_name,
prompt,
images: Vec::new(),
options: None,
system: None,
template: None,
Expand All @@ -33,6 +35,18 @@ impl GenerationRequest {
}
}

/// Sets the list of images to be used with the prompt, replacing any set earlier.
pub fn images(self, images: Vec<Image>) -> Self {
    let mut request = self;
    request.images = images;
    request
}

/// Appends a single image to the list of images used with the prompt.
pub fn add_image(self, image: Image) -> Self {
    let mut request = self;
    request.images.push(image);
    request
}

/// Additional model parameters listed in the documentation for the Modelfile
pub fn options(mut self, options: GenerationOptions) -> Self {
self.options = Some(options);
Expand Down
10 changes: 10 additions & 0 deletions src/generation/images.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};

/// An image that can be attached to a generation request or a chat message.
///
/// Wraps a `String`. `Image::from_base64` stores its argument unchanged, so the
/// content is presumably base64-encoded image data; nothing here validates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Image(String);

impl Image {
    /// Wraps an image that the caller has already encoded as base64.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `&String`, `String`, ...).
    /// Passing an owned `String` moves it in without copying, which matters because
    /// encoded images can be large; borrowed inputs are copied as before.
    ///
    /// The input is stored as-is: it is neither validated nor decoded here.
    pub fn from_base64(base64: impl Into<String>) -> Self {
        Self(base64.into())
    }
}
35 changes: 33 additions & 2 deletions tests/generation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#![allow(unused_imports)]

use base64::Engine;
use ollama_rs::{
generation::completion::{request::GenerationRequest, GenerationResponseStream},
generation::{
completion::{request::GenerationRequest, GenerationResponseStream},
images::Image,
},
Ollama,
};
use tokio::io::AsyncWriteExt;
Expand Down Expand Up @@ -48,3 +51,31 @@ async fn test_generation() {
.unwrap();
dbg!(res);
}

/// URL of the sample JPEG that `test_generation_with_images` downloads.
const IMAGE_URL: &str = "https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg";

/// Sends a prompt with one attached image through `Ollama::generate`.
///
/// Downloads `IMAGE_URL`, base64-encodes the bytes and attaches them to the request.
/// NOTE(review): presumably needs network access and a reachable Ollama server
/// with the `llava:latest` model available; confirm before running in CI.
#[tokio::test]
async fn test_generation_with_images() {
    let ollama = Ollama::default();

    // Fetch the sample picture and turn it into an `Image`.
    let response = reqwest::get(IMAGE_URL).await.unwrap();
    let bytes = response.bytes().await.unwrap();
    let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
    let image = Image::from_base64(&encoded);

    let request = GenerationRequest::new(
        "llava:latest".to_string(),
        "What can we see in this image?".to_string(),
    )
    .add_image(image);

    let res = ollama.generate(request).await.unwrap();
    dbg!(res);
}
35 changes: 34 additions & 1 deletion tests/send_chat_messages.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use base64::Engine;
use ollama_rs::{
generation::chat::{request::ChatMessageRequest, ChatMessage},
generation::{
chat::{request::ChatMessageRequest, ChatMessage},
images::Image,
},
Ollama,
};
use tokio_stream::StreamExt;
Expand Down Expand Up @@ -49,3 +53,32 @@ async fn test_send_chat_messages() {

assert!(res.done);
}

/// URL of the sample JPEG that `test_send_chat_messages_with_images` downloads.
const IMAGE_URL: &str = "https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg";

/// Sends a single user chat message with one attached image and checks that the
/// response is marked as done.
///
/// Downloads `IMAGE_URL`, base64-encodes the bytes and attaches them to the message.
/// NOTE(review): presumably needs network access and a reachable Ollama server
/// with the `llava:latest` model available; confirm before running in CI.
#[tokio::test]
async fn test_send_chat_messages_with_images() {
    let ollama = Ollama::default();

    // Fetch the sample picture and turn it into an `Image`.
    let response = reqwest::get(IMAGE_URL).await.unwrap();
    let bytes = response.bytes().await.unwrap();
    let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
    let image = Image::from_base64(&encoded);

    let question = ChatMessage::user("What can we see in this image?".to_string());
    let messages = vec![question.add_image(image)];
    let request = ChatMessageRequest::new("llava:latest".to_string(), messages);

    let res = ollama.send_chat_messages(request).await.unwrap();
    dbg!(&res);

    assert!(res.done);
}

0 comments on commit d7131fb

Please sign in to comment.