Skip to content

Commit

Permalink
Merge branch 'main' into databricks
Browse files Browse the repository at this point in the history
  • Loading branch information
t0mpl authored Jan 16, 2025
2 parents 842b1c0 + 71ded3c commit 9965b52
Show file tree
Hide file tree
Showing 16 changed files with 415 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GENAI_API_KEY: ${{ secrets.GENAI_API_KEY }}
run: go test -v ./...
run: go test -v ./... -race

3 changes: 3 additions & 0 deletions chains/chains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type testLanguageModel struct {
simulateWork time.Duration
// record the prompt that was passed to the language model
recordedPrompt []llms.PromptValue
mu sync.Mutex
}

type stringPromptValue struct {
Expand All @@ -46,9 +47,11 @@ func (l *testLanguageModel) GenerateContent(_ context.Context, mc []llms.Message
} else {
return nil, fmt.Errorf("passed non-text part")
}
l.mu.Lock()
l.recordedPrompt = []llms.PromptValue{
stringPromptValue{s: prompt},
}
l.mu.Unlock()

if l.simulateWork > 0 {
time.Sleep(l.simulateWork)
Expand Down
22 changes: 22 additions & 0 deletions examples/openai-jsonformat-example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# OpenAI Structured JSON Output Example

This example demonstrates how to use the OpenAI model with the LangChain Go library to generate structured JSON output.

## What does this example do?

This example program does the following:

1. Sets up a connection to the OpenAI API using the GPT-4o model.

2. Prompts the AI to generate a structured JSON output.


## How to Run

1. Make sure you have Go installed on your system.
2. Set up your OpenAI API key as an environment variable.
3. Run the program:

```
go run openai-jsonformat.go
```
11 changes: 11 additions & 0 deletions examples/openai-jsonformat-example/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module github.com/tmc/langchaingo/examples/openai-jsonformat-example

go 1.23

require github.com/tmc/langchaingo v0.1.13-pre.0.0.20250106145851-f1fde1f9e4a0

require (
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
)
58 changes: 58 additions & 0 deletions examples/openai-jsonformat-example/openai-jsonformat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package main

import (
"context"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"log"
)

// User mirrors the JSON schema requested from the model: every property the
// schema marks as required ("name", "age", "role") has a matching tagged
// field, so a structured response can be unmarshaled directly into this type.
type User struct {
	Name string `json:"name"`
	Age  int    `json:"age"`
	Role string `json:"role"`
}

// main demonstrates requesting structured JSON output from OpenAI by
// declaring a JSON-schema response format, sending a chat completion,
// and printing the structured result.
func main() {
	format := &openai.ResponseFormat{
		Type: "json_schema",
		JSONSchema: &openai.ResponseFormatJSONSchema{
			Name: "object",
			Schema: &openai.ResponseFormatJSONSchemaProperty{
				Type: "object",
				Properties: map[string]*openai.ResponseFormatJSONSchemaProperty{
					"name": {
						Type:        "string",
						Description: "The name of the user",
					},
					"age": {
						Type:        "integer",
						Description: "The age of the user",
					},
					"role": {
						Type:        "string",
						Description: "The role of the user",
					},
				},
				// Strict structured output: no keys outside the schema,
				// and every declared property must be present.
				AdditionalProperties: false,
				Required:             []string{"name", "age", "role"},
			},
			Strict: true,
		},
	}
	llm, err := openai.New(openai.WithModel("gpt-4o"), openai.WithResponseFormat(format))
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	content := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeSystem, "You are an expert at structured data extraction. You will be given unstructured text from a research paper and should convert it into the given structure."),
		llms.TextParts(llms.ChatMessageTypeHuman, "please tell me the most famous people in history"),
	}

	// NOTE(review): WithJSONMode sets a JSON response mode that may overlap
	// with the json_schema response format configured above — confirm both
	// are intended together.
	completion, err := llm.GenerateContent(ctx, content, llms.WithJSONMode())
	if err != nil {
		log.Fatal(err)
	}
	// Print the result on success. The original used log.Fatal here, which
	// would exit with a non-zero status even though the call succeeded.
	log.Println(completion.Choices[0].Content)
}
19 changes: 16 additions & 3 deletions llms/ernie/internal/ernieclient/ernieclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log"
"net/http"
"strings"
"sync"
"time"
)

Expand All @@ -26,6 +27,7 @@ type Client struct {
apiKey string
secretKey string
accessToken string
mu sync.RWMutex
httpClient Doer
Model string
ModelPath ModelPath
Expand Down Expand Up @@ -175,7 +177,9 @@ func autoRefresh(c *Client) error {
time.Sleep(tryPeriod * time.Minute) // try
continue
}
c.mu.Lock()
c.accessToken = authResp.AccessToken
c.mu.Unlock()
time.Sleep(10 * 24 * time.Hour)
}
}()
Expand All @@ -188,8 +192,11 @@ func (c *Client) CreateCompletion(ctx context.Context, modelPath ModelPath, r *C
modelPath = DefaultCompletionModelPath
}

c.mu.RLock()
accessToken := c.accessToken
c.mu.RUnlock()
url := "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + string(modelPath) +
"?access_token=" + c.accessToken
"?access_token=" + accessToken
body, e := json.Marshal(r)
if e != nil {
return nil, e
Expand Down Expand Up @@ -219,8 +226,11 @@ func (c *Client) CreateCompletion(ctx context.Context, modelPath ModelPath, r *C

// CreateEmbedding use ernie Embedding-V1.
func (c *Client) CreateEmbedding(ctx context.Context, texts []string) (*EmbeddingResponse, error) {
c.mu.RLock()
accessToken := c.accessToken
c.mu.RUnlock()
url := "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1?access_token=" +
c.accessToken
accessToken

payload := make(map[string]any)
payload["input"] = texts
Expand Down Expand Up @@ -346,8 +356,11 @@ func (c *Client) buildURL(modelpath ModelPath) string {

// ernie example url:
// /wenxinworkshop/chat/eb-instant
c.mu.RLock()
accessToken := c.accessToken
c.mu.RUnlock()
return fmt.Sprintf("%s/wenxinworkshop/chat/%s?access_token=%s",
baseURL, modelpath, c.accessToken,
baseURL, modelpath, accessToken,
)
}

Expand Down
54 changes: 43 additions & 11 deletions llms/ernie/internal/ernieclient/ernieclient_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
package ernieclient

import "testing"
import (
"encoding/json"
"io"
"net/http"
"strings"
"testing"
)

// mockHTTPClient is a stub HTTP client that always answers with a
// successful auth response carrying a fixed access token.
type mockHTTPClient struct{}

// Do implements the ernieclient.Doer interface.
func (m *mockHTTPClient) Do(_ *http.Request) (*http.Response, error) {
	payload, err := json.Marshal(&authResponse{
		AccessToken: "test",
	})
	if err != nil {
		return nil, err
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(string(payload))),
	}, nil
}

func TestClient_buildURL(t *testing.T) {
t.Parallel()
Expand All @@ -9,8 +33,6 @@ func TestClient_buildURL(t *testing.T) {
secretKey string
accessToken string
httpClient Doer
Model string
ModelPath ModelPath
}
type args struct {
modelpath ModelPath
Expand All @@ -22,25 +44,35 @@ func TestClient_buildURL(t *testing.T) {
want string
}{
{
name: "one",
name: "with access token",
fields: fields{
accessToken: "token",
},
args: args{modelpath: "eb-instant"},
want: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=token",
},
{
name: "with client, aksk",
fields: fields{
apiKey: "test",
secretKey: "test",
httpClient: &mockHTTPClient{},
},
args: args{modelpath: "eb-instant"},
want: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=test",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
c := &Client{
apiKey: tt.fields.apiKey,
secretKey: tt.fields.secretKey,
accessToken: tt.fields.accessToken,
httpClient: tt.fields.httpClient,
Model: tt.fields.Model,
ModelPath: tt.fields.ModelPath,
c, err := New(
WithAKSK(tt.fields.apiKey, tt.fields.secretKey),
WithAccessToken(tt.fields.accessToken),
WithHTTPClient(tt.fields.httpClient),
)
if err != nil {
t.Errorf("New got error. %v", err)
}
if got := c.buildURL(tt.args.modelpath); got != tt.want {
t.Errorf("buildURL() = %v, want %v", got, tt.want)
Expand Down
3 changes: 0 additions & 3 deletions llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,6 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCom

// Build request
body := bytes.NewReader(payloadBytes)
if c.baseURL == "" {
c.baseURL = defaultBaseURL
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/chat/completions", payload.Model), body)
if err != nil {
return nil, err
Expand Down
3 changes: 3 additions & 0 deletions llms/openai/internal/openaiclient/openaiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func New(token string, model string, baseURL string, organization string,
httpClient: httpClient,
ResponseFormat: responseFormat,
}
if c.baseURL == "" {
c.baseURL = defaultBaseURL
}

for _, opt := range opts {
if err := opt(c); err != nil {
Expand Down
16 changes: 12 additions & 4 deletions textsplitter/markdown_splitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"reflect"
"strings"
"unicode/utf8"

"gitlab.com/golang-commonmark/markdown"
)
Expand All @@ -25,6 +24,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter {
ReferenceLinks: options.ReferenceLinks,
HeadingHierarchy: options.KeepHeadingHierarchy,
JoinTableRows: options.JoinTableRows,
LenFunc: options.LenFunc,
}

if sp.SecondSplitter == nil {
Expand All @@ -36,6 +36,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter {
"\n", // new line
" ", // space
}),
WithLenFunc(options.LenFunc),
)
}

Expand All @@ -57,6 +58,7 @@ type MarkdownTextSplitter struct {
ReferenceLinks bool
HeadingHierarchy bool
JoinTableRows bool
LenFunc func(string) int
}

// SplitText splits a text into multiple text.
Expand All @@ -76,6 +78,7 @@ func (sp MarkdownTextSplitter) SplitText(text string) ([]string, error) {
joinTableRows: sp.JoinTableRows,
hTitleStack: []string{},
hTitlePrependHierarchy: sp.HeadingHierarchy,
lenFunc: sp.LenFunc,
}

chunks := mc.splitText()
Expand Down Expand Up @@ -133,6 +136,9 @@ type markdownContext struct {
// joinTableRows determines whether a chunk should contain multiple table rows,
// or if each row in a table should be split into a separate chunk.
joinTableRows bool

// lenFunc represents the function to calculate the length of a string.
lenFunc func(string) int
}

// splitText splits Markdown text.
Expand Down Expand Up @@ -193,6 +199,8 @@ func (mc *markdownContext) clone(startAt, endAt int) *markdownContext {
chunkSize: mc.chunkSize,
chunkOverlap: mc.chunkOverlap,
secondSplitter: mc.secondSplitter,

lenFunc: mc.lenFunc,
}
}

Expand Down Expand Up @@ -438,7 +446,7 @@ func (mc *markdownContext) splitTableRows(header []string, bodies [][]string) {
// If we're at the start of the current snippet, or adding the current line would
// overflow the chunk size, prepend the header to the line (so that the new chunk
// will include the table header).
if len(mc.curSnippet) == 0 || utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(line) >= mc.chunkSize {
if len(mc.curSnippet) == 0 || mc.lenFunc(mc.curSnippet+line) >= mc.chunkSize {
line = fmt.Sprintf("%s\n%s", headerMD, line)
}

Expand Down Expand Up @@ -617,7 +625,7 @@ func (mc *markdownContext) joinSnippet(snippet string) {
}

// check whether current chunk exceeds chunk size, if so, apply to chunks
if utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(snippet) >= mc.chunkSize {
if mc.lenFunc(mc.curSnippet+snippet) >= mc.chunkSize {
mc.applyToChunks()
mc.curSnippet = snippet
} else {
Expand All @@ -634,7 +642,7 @@ func (mc *markdownContext) applyToChunks() {
var chunks []string
if mc.curSnippet != "" {
// check whether current chunk is over ChunkSize,if so, re-split current chunk
if utf8.RuneCountInString(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap {
if mc.lenFunc(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap {
chunks = []string{mc.curSnippet}
} else {
// split current snippet to chunks
Expand Down
Loading

0 comments on commit 9965b52

Please sign in to comment.