-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfine_tunes.go
128 lines (108 loc) · 4.16 KB
/
fine_tunes.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package openai
import (
	"context"
	"fmt"
	"net/http"
	"net/url"
)
// FineTuneRequest is the value CreateFineTune hands to the request factory
// when POSTing to the fine-tunes endpoint (presumably marshalled as the JSON
// body — confirm in requestFactory.Build).
//
// Every field except TrainingFile is tagged `omitempty`, so encoding/json drops
// it whenever it holds its zero value (0, "", false, nil slice). Consequently an
// explicit zero — e.g. PromptLossWeight = 0 or NEpochs = 0 — cannot be sent;
// the field is omitted and whatever server-side default exists applies.
//
// NOTE(review): LearningRateMultiplier and PromptLossWeight are float32 here but
// float64 in FineTuneHyperparams; copying values from a response into a new
// request needs an explicit conversion. Confirm whether the mismatch is intended.
type FineTuneRequest struct {
// TrainingFile has no omitempty, so it is always serialized, even when empty.
// Presumably the ID of a previously uploaded file — confirm against API docs.
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
Model string `json:"model,omitempty"`
NEpochs int `json:"n_epochs,omitempty"`
BatchSize int `json:"batch_size,omitempty"`
LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"`
PromptLossWeight float32 `json:"prompt_loss_weight,omitempty"`
// ComputeClassificationMetrics is only serialized when true (omitempty on bool).
ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
ClassificationNClasses int `json:"classification_n_classes,omitempty"`
ClassificationPositiveClass string `json:"classification_positive_class,omitempty"`
// ClassificationBetas is omitted when nil or empty.
ClassificationBetas []float32 `json:"classification_betas,omitempty"`
Suffix string `json:"suffix,omitempty"`
}
// FineTune is the decoded representation of a fine-tune job. It is the result
// type of CreateFineTune, RetrieveFineTune and CancelFineTune, and the element
// type of FineTuneList.Data.
//
// No field carries omitempty; all of them are decoded from (and would be
// encoded to) the JSON keys shown in the tags. File is defined outside this
// section of the package.
type FineTune struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
// CreatedAt is the raw int64 from the payload; presumably a Unix timestamp in
// seconds — confirm against the API docs.
CreatedAt int64 `json:"created_at"`
Events []FineTuneEvent `json:"events"`
// FineTunedModel is presumably empty until the job has produced a model —
// verify against the API behavior.
FineTunedModel string `json:"fine_tuned_model"`
Hyperparams FineTuneHyperparams `json:"hyperparams"`
OrganizationID string `json:"organization_id"`
ResultFiles []File `json:"result_files"`
// Status is kept as a free-form string; no enumeration is defined here.
Status string `json:"status"`
ValidationFiles []File `json:"validation_files"`
TrainingFiles []File `json:"training_files"`
// UpdatedAt is the raw int64 from the payload; presumably a Unix timestamp in
// seconds — confirm against the API docs.
UpdatedAt int64 `json:"updated_at"`
}
// FineTuneEvent is a single event record attached to a fine-tune job. It is
// the element type of both FineTune.Events and FineTuneEventList.Data.
type FineTuneEvent struct {
Object string `json:"object"`
// CreatedAt is the raw int64 from the payload; presumably a Unix timestamp in
// seconds — confirm against the API docs.
CreatedAt int64 `json:"created_at"`
// Level is kept as a free-form string (presumably a severity such as "info");
// no enumeration is defined here.
Level string `json:"level"`
Message string `json:"message"`
}
// FineTuneHyperparams holds the hyperparameters decoded from a fine-tune job
// response (the FineTune.Hyperparams field).
//
// The numeric types are wider than their counterparts in FineTuneRequest
// (int64/float64 here vs int/float32 there), so values copied from a response
// into a new request need an explicit conversion.
type FineTuneHyperparams struct {
BatchSize int64 `json:"batch_size"`
LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
NEpochs int64 `json:"n_epochs"`
PromptLossWeight float64 `json:"prompt_loss_weight"`
}
// FineTuneList is the result type of ListFineTunes: a wrapper object whose
// Data slice holds the decoded fine-tune jobs.
type FineTuneList struct {
Object string `json:"object"`
Data []FineTune `json:"data"`
}
// FineTuneEventList is the result type of ListFineTuneEvents: a wrapper object
// whose Data slice holds the events of the requested fine-tune job.
type FineTuneEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
}
// FineTuneDeleteResponse is the result type of DeleteFineTuneModel.
type FineTuneDeleteResponse struct {
// Id mirrors the JSON "id" key. NOTE: it does not follow Go's initialism
// convention (FineTune uses ID). It is kept as-is because renaming an exported
// field would break callers, and adding a second field with the same json tag
// would make encoding/json ignore both.
Id string `json:"id"`
Object string `json:"object"`
// Deleted is the value of the JSON "deleted" key from the response.
Deleted bool `json:"deleted"`
}
// CreateFineTune starts a new fine-tune job.
//
// request is passed to the request factory for a POST to the fine-tunes
// endpoint (presumably marshalled as the JSON body), and the reply is decoded
// into a FineTune via sendRequest.
//
// If building the HTTP request fails, the zero FineTune is returned along with
// that error. Otherwise response is returned as left by sendRequest, together
// with whatever error sendRequest reported.
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
	httpReq, buildErr := c.requestFactory.Build(ctx, http.MethodPost, fullURL(fineTunes), request)
	if buildErr != nil {
		return response, buildErr
	}

	// sendRequest is called on its own line so that response is read only
	// after it has been populated (Go leaves the evaluation order of
	// `return response, f(&response)` unspecified).
	sendErr := c.sendRequest(httpReq, &response)
	return response, sendErr
}
// ListFineTunes fetches the fine-tune job list with a bodyless GET on the
// fine-tunes endpoint and decodes it into a FineTuneList.
//
// A request-construction failure yields an empty FineTuneList and that error.
// Otherwise response is returned as left by sendRequest, together with
// sendRequest's error.
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
	endpoint := fullURL(fineTunes)

	httpReq, buildErr := c.requestFactory.Build(ctx, http.MethodGet, endpoint, nil)
	if buildErr != nil {
		return FineTuneList{}, buildErr
	}

	err = c.sendRequest(httpReq, &response)
	return response, err
}
// RetrieveFineTune fetches one fine-tune job with a bodyless GET on
// <fine-tunes endpoint>/<id> and decodes the reply into a FineTune.
//
// id is path-escaped before being placed in the URL. Without this, characters
// such as '/', '?' or '#' (or a "../" dot-segment) in a caller-supplied id
// could retarget the request to a different endpoint or inject a query
// string. url.PathEscape leaves ':' and '-' intact, so ordinary IDs are sent
// unchanged.
//
// A request-construction failure yields the zero FineTune and that error.
// Otherwise response is returned as left by sendRequest, together with
// sendRequest's error.
func (c *Client) RetrieveFineTune(ctx context.Context, id string) (response FineTune, err error) {
	endpoint := fmt.Sprintf("%s/%s", fullURL(fineTunes), url.PathEscape(id))
	req, err := c.requestFactory.Build(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return response, err
	}
	err = c.sendRequest(req, &response)
	return response, err
}
// CancelFineTune asks the service to cancel a fine-tune job by sending a
// bodyless POST to <fine-tunes endpoint>/<id>/cancel, and decodes the reply
// into a FineTune.
//
// id is path-escaped before being placed in the URL so that characters such as
// '/', '?' or '#' in a caller-supplied id cannot change which endpoint receives
// the POST (e.g. "x/../other" would otherwise escape the /<id>/cancel route).
// url.PathEscape leaves ':' and '-' intact, so ordinary IDs are sent unchanged.
//
// A request-construction failure yields the zero FineTune and that error.
// Otherwise response is returned as left by sendRequest, together with
// sendRequest's error.
func (c *Client) CancelFineTune(ctx context.Context, id string) (response FineTune, err error) {
	endpoint := fmt.Sprintf("%s/%s/cancel", fullURL(fineTunes), url.PathEscape(id))
	req, err := c.requestFactory.Build(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return response, err
	}
	err = c.sendRequest(req, &response)
	return response, err
}
// ListFineTuneEvents fetches the events of one fine-tune job with a bodyless
// GET on <fine-tunes endpoint>/<id>/events and decodes the reply into a
// FineTuneEventList.
//
// id is path-escaped before being placed in the URL so that characters such as
// '/', '?' or '#' in a caller-supplied id cannot retarget the request or inject
// a query string. url.PathEscape leaves ':' and '-' intact, so ordinary IDs are
// sent unchanged.
//
// A request-construction failure yields an empty FineTuneEventList and that
// error. Otherwise response is returned as left by sendRequest, together with
// sendRequest's error.
func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (response FineTuneEventList, err error) {
	endpoint := fmt.Sprintf("%s/%s/events", fullURL(fineTunes), url.PathEscape(id))
	req, err := c.requestFactory.Build(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return response, err
	}
	err = c.sendRequest(req, &response)
	return response, err
}
// DeleteFineTuneModel sends a bodyless DELETE to <models endpoint>/<model> and
// decodes the reply into a FineTuneDeleteResponse.
//
// model is path-escaped before being placed in the URL. Because this is a
// destructive request, an unescaped value such as "../files/file-abc" or
// "m?x=y" must not be able to redirect the DELETE to a different resource.
// url.PathEscape keeps ':' unescaped, so fine-tuned model names of the form
// "base:ft-org-date" or "ft:base:org::id" are sent unchanged.
//
// A request-construction failure yields the zero FineTuneDeleteResponse and
// that error. Otherwise response is returned as left by sendRequest, together
// with sendRequest's error.
func (c *Client) DeleteFineTuneModel(ctx context.Context, model string) (response FineTuneDeleteResponse, err error) {
	endpoint := fmt.Sprintf("%s/%s", fullURL(models), url.PathEscape(model))
	req, err := c.requestFactory.Build(ctx, http.MethodDelete, endpoint, nil)
	if err != nil {
		return response, err
	}
	err = c.sendRequest(req, &response)
	return response, err
}