Skip to content

Commit

Permalink
New Image and Blob Part type (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Jan 7, 2025
1 parent a7bd410 commit f17d9b8
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 21 deletions.
2 changes: 1 addition & 1 deletion ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func testChatSession(model string, c ai.AI) error {
case ai.Text:
fmt.Println(i.Role, ":", v)
case ai.Image:
fmt.Printf("%s : [%s]", i.Role, v.MIMEType)
fmt.Printf("%s : [%s]", i.Role, v.MIMEType())
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ func (resp *ChatResponse[Response]) String() string {
}

func toImageBlock(img ai.Image) anthropic.ImageBlockParam {
return anthropic.NewImageBlockBase64(img.MIMEType, base64.StdEncoding.EncodeToString(img.Data))
mime, data := img.Data()
return anthropic.NewImageBlockBase64(mime, base64.StdEncoding.EncodeToString(data))
}

func fromImageBlockSource(src anthropic.ImageBlockParamSource) ai.Image {
Expand Down Expand Up @@ -285,7 +286,9 @@ func (c *Anthropic) createRequest(
case ai.Text:
msgs = append(msgs, anthropic.NewUserMessage(anthropic.NewTextBlock(string(v))))
case ai.Image:
msgs = append(msgs, anthropic.NewUserMessage(toImageBlock((v))))
msgs = append(msgs, anthropic.NewUserMessage(toImageBlock(v)))
case ai.Blob:
msgs = append(msgs, anthropic.NewUserMessage(toImageBlock(ai.ImageData(v.MIMEType, v.Data))))
case ai.FunctionResponse:
msgs = append(msgs, anthropic.NewUserMessage(anthropic.NewToolResultBlock(v.ID, v.Response, false)))
}
Expand Down Expand Up @@ -420,7 +423,9 @@ func (session *ChatSession) addUserHistory(messages ...ai.Part) {
case ai.Text:
session.history = append(session.history, anthropic.NewUserMessage(anthropic.NewTextBlock(string(v))))
case ai.Image:
session.history = append(session.history, anthropic.NewUserMessage(toImageBlock((v))))
session.history = append(session.history, anthropic.NewUserMessage(toImageBlock(v)))
case ai.Blob:
session.history = append(session.history, anthropic.NewUserMessage(toImageBlock(ai.ImageData(v.MIMEType, v.Data))))
case ai.FunctionResponse:
session.history = append(session.history, anthropic.NewUserMessage(anthropic.NewToolResultBlock(v.ID, v.Response, false)))
}
Expand Down
18 changes: 7 additions & 11 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@ package chatgpt

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -217,16 +214,11 @@ func (resp *ChatResponse[Response]) String() string {
}

func toImagePart(img ai.Image) openai.ChatCompletionContentPartImageParam {
return openai.ImagePart(fmt.Sprintf("data:%s;base64,%s", img.MIMEType, base64.StdEncoding.EncodeToString(img.Data)))
return openai.ImagePart(string(img))
}

func fromImagePart(img openai.ChatCompletionContentPartImageParam) ai.Image {
res := regexp.MustCompile("^data:(.+);base64,(.+)$").FindStringSubmatch(img.ImageURL.Value.URL.Value)
b, err := base64.StdEncoding.DecodeString(res[1])
if err != nil {
panic(err)
}
return ai.ImageData(res[0], b)
return ai.Image(img.ImageURL.Value.URL.Value)
}

func (c *ChatGPT) createRequest(
Expand Down Expand Up @@ -275,7 +267,9 @@ func (c *ChatGPT) createRequest(
case ai.Text:
msgs = append(msgs, openai.UserMessage(string(v)))
case ai.Image:
msgs = append(msgs, openai.UserMessageParts(toImagePart((v))))
msgs = append(msgs, openai.UserMessageParts(toImagePart(v)))
case ai.Blob:
msgs = append(msgs, openai.UserMessageParts(toImagePart(ai.ImageData(v.MIMEType, v.Data))))
case ai.FunctionResponse:
msgs = append(msgs, openai.ToolMessage(v.ID, v.Response))
}
Expand Down Expand Up @@ -403,6 +397,8 @@ func (session *ChatSession) addUserHistory(messages ...ai.Part) {
session.history = append(session.history, openai.UserMessage(string(v)))
case ai.Image:
session.history = append(session.history, openai.UserMessageParts(toImagePart(v)))
case ai.Blob:
session.history = append(session.history, openai.UserMessageParts(toImagePart(ai.ImageData(v.MIMEType, v.Data))))
case ai.FunctionResponse:
session.history = append(session.history, openai.ToolMessage(v.ID, v.Response))
}
Expand Down
69 changes: 64 additions & 5 deletions content.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
package ai

import (
"encoding/base64"
"io"
"net/http"
"net/url"
"strings"
)

type Content struct {
Parts []Part
Role string
Expand All @@ -13,20 +21,71 @@ type Text string

func (Text) implementsPart() {}

type Image struct {
type Blob struct {
MIMEType string
Data []byte
}

func (Blob) implementsPart() {}

type Image string

func ImageData(mime string, data []byte) Image {
return Image{
MIMEType: mime,
Data: data,
}
return Image("data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data))
}

func (Image) implementsPart() {}

func (img Image) MIMEType() string {
u, err := url.Parse(string(img))
if err != nil {
panic(err)
}
switch u.Scheme {
case "data":
mime, _, _ := strings.Cut(u.Opaque, ";")
return mime
case "http", "https":
resp, err := http.Head(u.String())
if err != nil {
panic(err)
}
defer resp.Body.Close()
return resp.Header.Get("Content-Type")
default:
panic("unsupported image scheme: " + u.Scheme)
}
}

func (img Image) Data() (mime string, data []byte) {
u, err := url.Parse(string(img))
if err != nil {
panic(err)
}
switch u.Scheme {
case "data":
mime, b64, _ := strings.Cut(u.Opaque, ";base64,")
b, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
panic(err)
}
return mime, b
case "http", "https":
resp, err := http.Get(u.String())
if err != nil {
panic(err)
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return resp.Header.Get("Content-Type"), b
default:
panic("unsupported image scheme: " + u.Scheme)
}
}

type FunctionCall struct {
ID string
Name string
Expand Down
5 changes: 4 additions & 1 deletion gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ func toParts(src []ai.Part) (dst []genai.Part) {
case ai.Text:
dst = append(dst, genai.Text(v))
case ai.Image:
mime, data := v.Data()
dst = append(dst, genai.Blob{MIMEType: mime, Data: data})
case ai.Blob:
dst = append(dst, genai.Blob(v))
case ai.FunctionCall:
b, err := json.Marshal(v.Arguments)
Expand Down Expand Up @@ -194,7 +197,7 @@ func fromParts(src []genai.Part) (dst []ai.Part) {
case genai.Text:
dst = append(dst, ai.Text(v))
case genai.Blob:
dst = append(dst, ai.Image(v))
dst = append(dst, ai.Blob(v))
case genai.FunctionCall:
b, err := json.Marshal(v.Args)
if err != nil {
Expand Down

0 comments on commit f17d9b8

Please sign in to comment.