Skip to content

Commit

Permalink
prompt: Add newWorkers according to ai.Limit() (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Jul 12, 2024
1 parent b3654ac commit fdbdaa0
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 28 deletions.
13 changes: 11 additions & 2 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package chatgpt
import (
"context"
"io"
"math"
"net/http"
"net/url"
"time"

"github.com/sunshineplan/ai"

Expand Down Expand Up @@ -74,8 +76,15 @@ func (chatgpt *ChatGPT) Model(_ context.Context) (string, error) {
return chatgpt.model, nil
}

func (chatgpt *ChatGPT) SetLimit(limit rate.Limit) {
chatgpt.limiter = ai.NewLimiter(limit)
func (chatgpt *ChatGPT) SetLimit(rpm int64) {
chatgpt.limiter = ai.NewLimiter(rpm)
}

func (chatgpt *ChatGPT) Limit() (rpm int64) {
if chatgpt.limiter == nil {
return math.MaxInt64
}
return int64(chatgpt.limiter.Limit() / rate.Every(time.Minute))
}

func (ai *ChatGPT) wait(ctx context.Context) error {
Expand Down
10 changes: 4 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package ai

import "golang.org/x/time/rate"

type ClientConfig struct {
LLMs LLMs

APIKey string
Endpoint string
Proxy string

Limit *rate.Limit
Limit *int64

Model string
ModelConfig ModelConfig
Expand Down Expand Up @@ -48,7 +46,7 @@ type ClientOption interface {
func WithAPIKey(apiKey string) ClientOption { return withAPIKey(apiKey) }
func WithEndpoint(endpoint string) ClientOption { return withEndpoint(endpoint) }
func WithProxy(proxy string) ClientOption { return withProxy(proxy) }
func WithLimit(limit rate.Limit) ClientOption { return withLimit(limit) }
func WithLimit(rpm int64) ClientOption { return withLimit(rpm) }
func WithModel(model string) ClientOption { return withModel(model) }
func WithModelConfig(config ModelConfig) ClientOption { return withModelConfig(config) }

Expand All @@ -64,9 +62,9 @@ type withProxy string

func (w withProxy) Apply(cfg *ClientConfig) { cfg.Proxy = string(w) }

type withLimit rate.Limit
type withLimit int64

func (w withLimit) Apply(cfg *ClientConfig) { cfg.Limit = (*rate.Limit)(&w) }
func (w withLimit) Apply(cfg *ClientConfig) { cfg.Limit = (*int64)(&w) }

type withModel string

Expand Down
13 changes: 11 additions & 2 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"errors"
"io"
"math"
"net/http"
"net/url"
"strings"
"time"

"github.com/sunshineplan/ai"

Expand Down Expand Up @@ -81,8 +83,15 @@ func (gemini *Gemini) Model(ctx context.Context) (string, error) {
return info.Name, nil
}

func (gemini *Gemini) SetLimit(limit rate.Limit) {
gemini.limiter = ai.NewLimiter(limit)
func (gemini *Gemini) SetLimit(rpm int64) {
gemini.limiter = ai.NewLimiter(rpm)
}

func (gemini *Gemini) Limit() (rpm int64) {
if gemini.limiter == nil {
return math.MaxInt64
}
return int64(gemini.limiter.Limit() / rate.Every(time.Minute))
}

func (ai *Gemini) wait(ctx context.Context) error {
Expand Down
10 changes: 6 additions & 4 deletions limiter.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
package ai

import (
"math"
"time"

"golang.org/x/time/rate"
)

type Limiter interface {
SetLimit(rate.Limit)
SetLimit(rpm int64)
Limit() (rpm int64)
}

func NewLimiter(limit rate.Limit) *rate.Limiter {
if limit == rate.Inf {
func NewLimiter(rpm int64) *rate.Limiter {
if rpm == math.MaxInt64 {
return nil
}
return rate.NewLimiter(rate.Every(time.Minute)*limit, int(limit))
return rate.NewLimiter(rate.Every(time.Minute)*rate.Limit(rpm), int(rpm))
}
27 changes: 13 additions & 14 deletions prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package prompt
import (
"context"
"fmt"
"math"
"strings"
"text/template"
"time"
Expand All @@ -11,10 +12,7 @@ import (
"github.com/sunshineplan/workers"
)

const (
defaultTimeout = 5 * time.Minute
defaultWorkers = 3
)
const defaultTimeout = 3 * time.Minute

const defaultTemplate = `{{.Request}}{{if .Example}}
###
Expand Down Expand Up @@ -49,12 +47,11 @@ type Prompt struct {
ex *Example
n int

d time.Duration
workers int64
d time.Duration
}

func New(prompt string) *Prompt {
p := &Prompt{prompt: prompt, d: defaultTimeout, workers: defaultWorkers}
p := &Prompt{prompt: prompt, d: defaultTimeout}
p.t = template.Must(template.New("prompt").Funcs(defaultFuncMap).Parse(defaultTemplate))
return p
}
Expand All @@ -79,11 +76,6 @@ func (prompt *Prompt) SetAITimeout(d time.Duration) *Prompt {
return prompt
}

func (prompt *Prompt) SetWorkers(n int64) *Prompt {
prompt.workers = n
return prompt
}

func (prompt *Prompt) Prompts(input []string, prefix string) (prompts []string, err error) {
length := len(input)
if length == 0 {
Expand Down Expand Up @@ -123,6 +115,13 @@ type Result struct {
Error error
}

func newWorkers(ai ai.AI) *workers.Workers {
if rpm := ai.Limit(); rpm != math.MaxInt64 {
return workers.NewWorkers(rpm)
}
return workers.NewWorkers(0)
}

func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan *Result, int, error) {
prompts, err := prompt.Prompts(input, prefix)
if err != nil {
Expand All @@ -131,7 +130,7 @@ func (prompt *Prompt) Execute(ai ai.AI, input []string, prefix string) (<-chan *
n := len(prompts)
c := make(chan *Result, n)
go func() {
workers.NewWorkers(prompt.workers).Run(context.Background(), workers.SliceJob(prompts, func(i int, p string) {
newWorkers(ai).Run(context.Background(), workers.SliceJob(prompts, func(i int, p string) {
resp, err := chat(ai, prompt.d, p)
if err != nil {
c <- &Result{i, p, nil, 0, err}
Expand All @@ -150,7 +149,7 @@ func (prompt *Prompt) JobList(ctx context.Context, ai ai.AI, input []string, pre
if err != nil {
return nil, 0, err
}
jobList := workers.NewJobList(workers.NewWorkers(prompt.workers), func(r *Result) {
jobList := workers.NewJobList(newWorkers(ai), func(r *Result) {
resp, err := chat(ai, prompt.d, r.Prompt)
if err != nil {
r.Result = nil
Expand Down

0 comments on commit fdbdaa0

Please sign in to comment.