Skip to content

Commit

Permalink
ollama: Fix JSON format bug issue when not streaming (#892)
Browse files Browse the repository at this point in the history
* Gracefully handle the LLM emitting whitespace in JSON mode with Ollama.

* ollama: Simplify stream repr, spruce up fn calling example

---------

Co-authored-by: Travis Cline <travis.cline@gmail.com>
  • Loading branch information
doslindos and tmc authored Jun 16, 2024
1 parent e380482 commit c6b8f4f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
33 changes: 17 additions & 16 deletions examples/ollama-functions-example/ollama_functions_example.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
Expand All @@ -12,10 +13,13 @@ import (
"github.com/tmc/langchaingo/llms/ollama"
)

var flagVerbose = flag.Bool("v", false, "verbose mode")

func main() {
flag.Parse()
// allow specifying your own model via OLLAMA_TEST_MODEL
// (same as the Ollama unit tests).
model := "mistral:instruct"
model := "llama3"
if v := os.Getenv("OLLAMA_TEST_MODEL"); v != "" {
model = v
}
Expand All @@ -31,14 +35,12 @@ func main() {
var msgs []llms.MessageContent

// system message defines the available tools.
msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeSystem,
systemMessage()))
msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman,
"What's the weather like in Beijing?"))
msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeSystem, systemMessage()))
msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman, "What's the weather like in Beijing?"))

ctx := context.Background()

for {
for retries := 3; retries > 0; retries = retries - 1 {
resp, err := llm.GenerateContent(ctx, msgs)
if err != nil {
log.Fatal(err)
Expand All @@ -49,19 +51,23 @@ func main() {

if c := unmarshalCall(choice1.Content); c != nil {
log.Printf("Call: %v", c.Tool)

if *flagVerbose {
log.Printf("Call: %v (raw: %v)", c.Tool, choice1.Content)
}
msg, cont := dispatchCall(c)
if !cont {
break
}

msgs = append(msgs, msg)
} else {
// Ollama doesn't always respond with a function call; let it try again.
log.Printf("Not a call: %v", choice1.Content)

msgs = append(msgs, llms.TextParts(llms.ChatMessageTypeHuman, "Sorry, I don't understand. Please try again."))
}

if retries == 0 {
log.Fatal("retries exhausted")
}
}
}

Expand All @@ -72,20 +78,17 @@ type Call struct {

// unmarshalCall attempts to decode input as a JSON-encoded Call.
// It returns nil when input is not valid JSON for a Call or when the decoded
// Call has an empty Tool; otherwise it returns a pointer to the decoded value.
func unmarshalCall(input string) *Call {
	call := new(Call)
	if err := json.Unmarshal([]byte(input), call); err != nil {
		// Not JSON (or not shaped like a Call): treat as "not a call".
		return nil
	}
	if call.Tool == "" {
		// Valid JSON that names no tool is not a usable call either.
		return nil
	}
	return call
}

func dispatchCall(c *Call) (llms.MessageContent, bool) {
// Ollama doesn't always respond with a *valid* function call. Because the tools are
// injected via prompt engineering, the model may hallucinate a tool name.
if !validTool(c.Tool) {
log.Printf("invalid function call: %#v", c)

log.Printf("invalid function call: %#v, prompting model to try again", c)
return llms.TextParts(llms.ChatMessageTypeHuman,
"Tool does not exist, please try again."), true
}
Expand All @@ -106,7 +109,7 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {
if err != nil {
log.Fatal(err)
}
return llms.TextParts(llms.ChatMessageTypeSystem, weather), true
return llms.TextParts(llms.ChatMessageTypeHuman, weather), true
case "finalResponse":
resp, ok := c.Input["response"].(string)
if !ok {
Expand All @@ -124,11 +127,9 @@ func dispatchCall(c *Call) (llms.MessageContent, bool) {

// validTool reports whether name matches the Name of any entry in the
// file-level functions collection.
func validTool(name string) bool {
	known := make([]string, 0, len(functions))
	for _, fn := range functions {
		known = append(known, fn.Name)
	}
	found := slices.Contains(known, name)
	return found
}

Expand Down
2 changes: 1 addition & 1 deletion llms/ollama/internal/ollamaclient/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type Message struct {
type ChatRequest struct {
Model string `json:"model"`
Messages []*Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Stream bool `json:"stream,omitempty"`
Format string `json:"format"`
KeepAlive string `json:"keep_alive,omitempty"`

Expand Down
4 changes: 2 additions & 2 deletions llms/ollama/ollamallm.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
Format: format,
Messages: chatMsgs,
Options: ollamaOptions,
Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil),
Stream: opts.StreamingFunc != nil,
}

keepAlive := o.options.keepAlive
Expand All @@ -129,7 +129,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
if response.Message != nil {
streamedResponse += response.Message.Content
}
if response.Done {
if !req.Stream || response.Done {
resp = response
resp.Message = &ollamaclient.Message{
Role: "assistant",
Expand Down

0 comments on commit c6b8f4f

Please sign in to comment.