Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate older commands, update random-vectors #43

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions benchmarker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ Usage:
benchmarker [command]
Available Commands:
ann-benchmark Benchmark ANN Benchmark style hdf5 files (this is generally what you want to use)
dataset Benchmark vectors from an existing dataset
ann-benchmark Benchmark ANN Benchmark style datasets
help Help about any command
random-text Benchmark nearText searches
random-vectors Benchmark nearVector searches
random-vectors Benchmark random vector queries
raw Benchmark raw GraphQL queries
Flags:
Expand Down Expand Up @@ -129,7 +127,7 @@ go run . \
--dimensions 384 \
--queries 10000 \
--parallel 8 \
--api graphql \
--api grpc \
--limit 10
```

Expand Down Expand Up @@ -158,7 +156,7 @@ benchmarker \
--dimensions 384 \
--queries 10000 \
--parallel 8 \
--api graphql \
--api grpc \
--limit 10
```

2 changes: 1 addition & 1 deletion benchmarker/cmd/ann_benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ func runQueries(cfg *Config, importTime time.Duration, testData [][]float32, nei

var annBenchmarkCommand = &cobra.Command{
Use: "ann-benchmark",
Short: "Benchmark ANN Benchmark style hdf5 files",
Short: "Benchmark ANN Benchmark style datasets",
Long: `Run a gRPC benchmark on an hdf5 file in the format of ann-benchmarks.com`,
Run: func(cmd *cobra.Command, args []string) {

Expand Down
5 changes: 3 additions & 2 deletions benchmarker/cmd/benchmark_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ func processQueueHttp(queue []QueryWithNeighbors, cfg *Config, c *http.Client, m
r := bytes.NewReader(query.Query)
before := time.Now()
var url string
origin := fmt.Sprintf("%s://%s", cfg.HttpScheme, cfg.HttpOrigin)
if cfg.API == "graphql" {
url = cfg.Origin + "/v1/graphql"
url = origin + "/v1/graphql"
} else if cfg.API == "rest" {
url = fmt.Sprintf("%s/v1/objects/%s/_search", cfg.Origin, cfg.ClassName)
url = fmt.Sprintf("%s/v1/objects/%s/_search", origin, cfg.ClassName)
}
req, err := http.NewRequest("POST", url, r)
if err != nil {
Expand Down
6 changes: 1 addition & 5 deletions benchmarker/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *Config) validateCommon() error {
}

switch c.API {
case "graphql", "nearvector", "grpc":
case "graphql", "grpc":
default:
return errors.Errorf("unsupported API %q", c.API)
}
Expand Down Expand Up @@ -128,10 +128,6 @@ func (c Config) validateRandomText() error {
}

func (c Config) validateRandomVectors() error {
if c.Dimensions == 0 {
return errors.Errorf("dimensions must be set and larger than 0\n")
}

return nil
}

Expand Down
15 changes: 5 additions & 10 deletions benchmarker/cmd/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (
)

var datasetCmd = &cobra.Command{
Use: "dataset",
Short: "Benchmark vectors from an existing dataset",
Long: `Specify an existing dataset as a list of query vectors in a .json file to parse the query vectors and then query them with the specified parallelism`,
Use: "dataset",
Short: "Benchmark vectors from an existing dataset",
Long: `Specify an existing dataset as a list of query vectors in a .json file to parse the query vectors and then query them with the specified parallelism`,
Deprecated: "This command is deprecated and will be removed in the future",
Run: func(cmd *cobra.Command, args []string) {
cfg := globalConfig
cfg.Mode = "dataset"
Expand Down Expand Up @@ -112,15 +113,9 @@ func benchmarkDataset(cfg Config, queries Queries) Results {
}
}

if cfg.API == "rest" {
return QueryWithNeighbors{
Query: nearVectorQueryJSONRest(cfg.ClassName, queries[i], cfg.Limit),
}
}

if cfg.API == "grpc" {
return QueryWithNeighbors{
Query: nearVectorQueryGrpc(&cfg, queries[i], cfg.Tenant, 0),
Query: nearVectorQueryGrpc(&cfg, queries[i], cfg.Tenant, -1),
}
}

Expand Down
1 change: 1 addition & 0 deletions benchmarker/cmd/random_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var randomTextCmd = &cobra.Command{
Use: "random-text",
Short: "Benchmark nearText searches",
Long: `Benchmark random nearText searches`,
Deprecated: "This command is deprecated and will be removed in the future",
Run: func(cmd *cobra.Command, args []string) {
cfg := globalConfig

Expand Down
154 changes: 93 additions & 61 deletions benchmarker/cmd/random_vectors.go
Original file line number Diff line number Diff line change
@@ -1,52 +1,55 @@
package cmd

import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"math/rand"
"os"
"runtime"
"strconv"
"strings"
"time"

log "github.com/sirupsen/logrus"

"github.com/spf13/cobra"
"github.com/weaviate/weaviate-go-client/v4/weaviate"
weaviategrpc "github.com/weaviate/weaviate/grpc/generated/protocol/v1"

"google.golang.org/protobuf/proto"
)

func initRandomVectors() {
rootCmd.AddCommand(randomVectorsCmd)
numCPU := runtime.NumCPU()
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Queries,
"queries", "q", 100, "Set the number of queries the benchmarker should run")
randomVectorsCmd.PersistentFlags().IntVar(&globalConfig.QueryDuration,
"queryDuration", 0, "Instead of a fixed number of queries, query for the specified duration in seconds (default 0)")
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Parallel,
"parallel", "p", 8, "Set the number of parallel threads which send queries")
"parallel", "p", numCPU, "Set the number of parallel threads which send queries")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.API,
"api", "a", "grpc", "API (graphql | grpc) default and recommended is grpc")
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Limit,
"limit", "l", 10, "Set the query limit (top_k)")
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Dimensions,
"dimensions", "d", 768, "Set the vector dimensions (must match your data)")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.WhereFilter,
"where", "w", "", "An entire where filter as a string")
"dimensions", "d", 0, "Set the vector dimensions (will infer from class if not set)")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.ClassName,
"className", "c", "", "The Weaviate class to run the benchmark against")
randomVectorsCmd.PersistentFlags().StringVar(&globalConfig.DB,
"db", "weaviate", "The tool you're benchmarking")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.API,
"api", "a", "graphql", "The API to use on benchmarks")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.Origin,
"origin", "u", "http://localhost:8080", "The origin that Weaviate is running at")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.OutputFormat,
"format", "f", "text", "Output format, one of [text, json]")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.OutputFile,
"output", "o", "", "Filename for an output file. If none provided, output to stdout only")
"grpcOrigin", "u", "localhost:50051", "The gRPC origin that Weaviate is running at")
randomVectorsCmd.PersistentFlags().StringVar(&globalConfig.HttpOrigin,
"httpOrigin", "localhost:8080", "The http origin for Weaviate (without http scheme)")
randomVectorsCmd.PersistentFlags().StringVar(&globalConfig.HttpScheme,
"httpScheme", "http", "The http scheme (http or https)")
}

var randomVectorsCmd = &cobra.Command{
Use: "random-vectors",
Short: "Benchmark nearVector searches",
Long: `Benchmark random nearVector searches`,
Short: "Benchmark random vector queries",
Long: `Benchmark random vector queries`,
Run: func(cmd *cobra.Command, args []string) {
cfg := globalConfig
cfg.Mode = "random-vectors"
Expand All @@ -56,43 +59,49 @@ var randomVectorsCmd = &cobra.Command{
os.Exit(1)
}

if len(cfg.WhereFilter) > 0 {
filter := fmt.Sprintf(", where: { %s }", cfg.WhereFilter)
cfg.WhereFilter = strings.Replace(filter, "\"", "\\\"", -1)
}
log.WithFields(log.Fields{"queries": cfg.Queries,
"class": cfg.ClassName}).Info("Beginning random-vectors benchmark")

if cfg.DB == "weaviate" {
client := createClient(&cfg)
cfg.Dimensions = getDimensions(cfg, client)

var w io.Writer
if cfg.OutputFile == "" {
w = os.Stdout
} else {
f, err := os.Create(cfg.OutputFile)
if err != nil {
fatal(err)
}
var result Results

defer f.Close()
w = f
if cfg.QueryDuration > 0 {
result = benchmarkNearVectorDuration(cfg)
} else {
result = benchmarkNearVector(cfg)
}

}
log.WithFields(log.Fields{"mean": result.Mean, "qps": result.QueriesPerSecond,
"parallel": cfg.Parallel, "limit": cfg.Limit,
"api": cfg.API, "count": result.Total, "failed": result.Failed}).Info("Benchmark result")

result := benchmarkNearVector(cfg)
if cfg.OutputFormat == "json" {
result.WriteJSONTo(w)
} else if cfg.OutputFormat == "text" {
result.WriteTextTo(w)
}
},
}

if cfg.OutputFile != "" {
infof("results succesfully written to %q", cfg.OutputFile)
func getDimensions(cfg Config, client *weaviate.Client) int {
dimensions := cfg.Dimensions
if cfg.Dimensions == 0 {
// Try to infer dimensions from class

objects, err := client.Data().ObjectsGetter().WithClassName(cfg.ClassName).WithVector().WithLimit(10).Do(context.Background())
if err != nil {
log.Infof("Error fetching class %s, %v", cfg.ClassName, err)
}

for _, obj := range objects {
if obj.Vector != nil {
dimensions = len(obj.Vector)
break
}
return
}

fmt.Printf("unrecognized db\n")
os.Exit(1)
},
if dimensions == 0 {
log.Fatalf("Could not fetch dimensions from class %s", cfg.ClassName)
}
}
return dimensions
}

func randomVector(dims int) []float32 {
Expand All @@ -118,14 +127,6 @@ func nearVectorQueryJSONGraphQL(className string, vec []float32, limit int, wher
}`, className, limit, string(vecJSON), whereFilter))
}

func nearVectorQueryJSONRest(className string, vec []float32, limit int) []byte {
vecJSON, _ := json.Marshal(vec)
return []byte(fmt.Sprintf(`{
"nearVector":{"vector":%s},
"limit":%d
}`, string(vecJSON), limit))
}

func encodeVector(fs []float32) []byte {
buf := make([]byte, len(fs)*4)
for i, f := range fs {
Expand Down Expand Up @@ -190,19 +191,50 @@ func benchmarkNearVector(cfg Config) Results {
Query: nearVectorQueryJSONGraphQL(cfg.ClassName, randomVector(cfg.Dimensions), cfg.Limit, cfg.WhereFilter),
}
}

if cfg.API == "rest" {
return QueryWithNeighbors{
Query: nearVectorQueryJSONRest(cfg.ClassName, randomVector(cfg.Dimensions), cfg.Limit),
}
}

if cfg.API == "grpc" {
return QueryWithNeighbors{
Query: nearVectorQueryGrpc(&cfg, randomVector(cfg.Dimensions), cfg.Tenant, 0),
Query: nearVectorQueryGrpc(&cfg, randomVector(cfg.Dimensions), cfg.Tenant, -1),
}
}

return QueryWithNeighbors{}
})
}

func benchmarkNearVectorDuration(cfg Config) Results {

var samples sampledResults

startTime := time.Now()

var results Results
iterations := 0
for time.Since(startTime) < time.Duration(cfg.QueryDuration)*time.Second {
results = benchmarkNearVector(cfg)
samples.Min = append(samples.Min, results.Min)
samples.Max = append(samples.Max, results.Max)
samples.Mean = append(samples.Mean, results.Mean)
samples.Took = append(samples.Took, results.Took)
samples.QueriesPerSecond = append(samples.QueriesPerSecond, results.QueriesPerSecond)
samples.Results = append(samples.Results, results)
iterations += 1
}

var medianResult Results

medianResult.Min = time.Duration(median(samples.Min))
medianResult.Max = time.Duration(median(samples.Max))
medianResult.Mean = time.Duration(median(samples.Mean))
medianResult.Took = time.Duration(median(samples.Took))
medianResult.QueriesPerSecond = median(samples.QueriesPerSecond)
medianResult.Percentiles = results.Percentiles
medianResult.PercentilesLabels = results.PercentilesLabels
medianResult.Total = results.Total
medianResult.Successful = results.Successful
medianResult.Failed = results.Failed
medianResult.Parallelization = cfg.Parallel

log.WithFields(log.Fields{"iterations": iterations}).Infof("Queried for %d seconds", cfg.QueryDuration)

return medianResult
}