diff --git a/ai_test.go b/ai_test.go index cac71e0..5ce5e43 100644 --- a/ai_test.go +++ b/ai_test.go @@ -95,6 +95,7 @@ func TestGemini(t *testing.T) { return } gemini, err := gemini.New( + context.Background(), ai.WithAPIKey(apiKey), ai.WithEndpoint(os.Getenv("GEMINI_ENDPOINT")), ai.WithProxy(os.Getenv("GEMINI_PROXY")), diff --git a/client/client.go b/client/client.go index 4fd268f..2ee8d72 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "errors" "github.com/sunshineplan/ai" @@ -26,7 +27,7 @@ func New(cfg ai.ClientConfig) (client ai.AI, err error) { case ai.ChatGPT: client, err = chatgpt.New(opts...) case ai.Gemini: - client, err = gemini.New(opts...) + client, err = gemini.New(context.Background(), opts...) default: err = errors.New("unknown LLMs: " + string(cfg.LLMs)) } diff --git a/gemini/gemini.go b/gemini/gemini.go index ac0f833..306fe44 100644 --- a/gemini/gemini.go +++ b/gemini/gemini.go @@ -28,7 +28,7 @@ type Gemini struct { limiter *rate.Limiter } -func New(opts ...ai.ClientOption) (ai.AI, error) { +func New(ctx context.Context, opts ...ai.ClientOption) (ai.AI, error) { cfg := new(ai.ClientConfig) for _, i := range opts { i.Apply(cfg) @@ -50,7 +50,7 @@ func New(opts ...ai.ClientOption) (ai.AI, error) { if cfg.Endpoint != "" { o = append(o, option.WithEndpoint(cfg.Endpoint)) } - client, err := genai.NewClient(context.Background(), o...) + client, err := genai.NewClient(ctx, o...) if err != nil { return nil, err } diff --git a/prompt/prompt.go b/prompt/prompt.go index 2e326ec..28a373c 100644 --- a/prompt/prompt.go +++ b/prompt/prompt.go @@ -12,7 +12,7 @@ import ( ) const ( - defaultTimeout = time.Minute + defaultTimeout = 5 * time.Minute defaultWorkers = 3 ) @@ -144,10 +144,10 @@ func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan * return c, n, nil } -func (prompt *Prompt) JobList(ai ai.AI, input []string, prefix string, c chan<- *Result) (*workers.JobList[*Result], error) { +func (prompt *Prompt) JobList(ai ai.AI, input []string, prefix string, c chan<- *Result) (*workers.JobList[*Result], int, error) { prompts, err := prompt.Prompts(input, prefix) if err != nil { - return nil, err + return nil, 0, err } jobList := workers.NewJobList(workers.NewWorkers(prompt.workers), func(r *Result) { resp, err := chat(ai, prompt.d, r.Prompt) @@ -163,7 +163,7 @@ func (prompt *Prompt) JobList(ai ai.AI, input []string, prefix string, c chan<- for i, p := range prompts { jobList.PushBack(&Result{Index: i, Prompt: p}) } - return jobList, nil + return jobList, len(prompts), nil } func chat(ai ai.AI, d time.Duration, p string) (ai.ChatResponse, error) {