forked from phuongdo/catboost-go
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathregression.go
37 lines (33 loc) · 906 Bytes
/
regression.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
package catboost
// Regression holds a *Model and exposes regression-specific helpers
// (PredictRegression, Close) on top of it.
type Regression struct {
	Model *Model
}
// LoadRegressionFromFile loads a model from filename using
// LoadFullModelFromFile and wraps it in a Regression. If loading fails, it
// returns a nil *Regression together with the loader's error unchanged.
func LoadRegressionFromFile(filename string) (*Regression, error) {
	m, loadErr := LoadFullModelFromFile(filename)
	if loadErr == nil {
		return &Regression{Model: m}, nil
	}
	return nil, loadErr
}
// PredictRegression evaluates the wrapped model on float, categorical, text
// and embedding features. It passes every argument unchanged to
// Model.CalcModelPredictionTextAndEmbeddings.
//
// On success it returns the model's predictions. If the model reports an
// error, any partial predictions are discarded and nil is returned together
// with that error.
func (r *Regression) PredictRegression(
	floats [][]float32, floatLength int,
	cats [][]string, catLength int,
	texts [][]string, textLength int,
	embeddings [][][]float32, embeddingDimensions []int, embeddingSize int,
) ([]float64, error) {
	m := r.Model
	predictions, calcErr := m.CalcModelPredictionTextAndEmbeddings(
		floats, floatLength,
		cats, catLength,
		texts, textLength,
		embeddings, embeddingDimensions, embeddingSize,
	)
	if calcErr == nil {
		return predictions, nil
	}
	return nil, calcErr
}
// Close releases the wrapped model by calling its Close method. It then drops
// the reference, so any later call to Close on the same Regression does
// nothing. Calling Close on a nil *Regression, or on one whose Model is nil,
// is also a no-op.
func (r *Regression) Close() {
	if r == nil || r.Model == nil {
		return
	}
	// Clear the field before releasing, so a repeated Close cannot hand the
	// same model to Model.Close twice.
	// NOTE(review): assumes Model.Close is not safe to call twice on one model
	// (e.g. it frees a native handle) — confirm against its implementation.
	m := r.Model
	r.Model = nil
	m.Close()
}