Skip to content

Commit

Permalink
Fix SetFunctionCall if has zero tools (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Dec 10, 2024
1 parent ac737ef commit 8675070
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 2 deletions.
4 changes: 4 additions & 0 deletions ai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ func testFunctionCall(t *testing.T, model string, c ai.AI) {
c.SetModel(model)
}
c.SetTemperature(0)
c.SetFunctionCall(nil, ai.FunctionCallingAuto)
if _, err := c.Chat(context.Background(), ai.Text("Which theaters in Mountain View show Barbie movie?")); err != nil {
t.Fatal(err)
}
movieChat := func(t *testing.T, s ai.Schema, fcm ai.FunctionCallingMode) {
movieTool := ai.Function{
Name: "find_theaters",
Expand Down
5 changes: 4 additions & 1 deletion anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ func (ai *Anthropic) wait(ctx context.Context) error {

func (ai *Anthropic) SetModel(model string) { ai.model = model }
func (a *Anthropic) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMode) {
a.tools = nil
if a.tools = nil; len(f) == 0 {
a.toolChoice = nil
return
}
for _, i := range f {
a.tools = append(a.tools, anthropic.ToolParam{
Name: anthropic.String(i.Name),
Expand Down
5 changes: 4 additions & 1 deletion chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ func (ai *ChatGPT) wait(ctx context.Context) error {

func (ai *ChatGPT) SetModel(model string) { ai.model = model }
func (chatgpt *ChatGPT) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMode) {
chatgpt.tools = nil
if chatgpt.tools = nil; len(f) == 0 {
chatgpt.toolChoice = nil
return
}
for _, i := range f {
var parameters shared.FunctionParameters
b, _ := json.Marshal(i.Parameters)
Expand Down
5 changes: 5 additions & 0 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ func (ai *Gemini) SetModel(model string) {
}

func (gemini *Gemini) SetFunctionCall(f []ai.Function, mode ai.FunctionCallingMode) {
if len(f) == 0 {
gemini.model.Tools = nil
gemini.model.ToolConfig = nil
return
}
var declarations []*genai.FunctionDeclaration
for _, i := range f {
p, err := genaiProperties(i.Parameters.Properties)
Expand Down

0 comments on commit 8675070

Please sign in to comment.