Skip to content

Commit

Permalink
Deprecate older commands, update random-vectors (#43)
Browse files Browse the repository at this point in the history
* Deprecate older commands, update random-vectors

* Fix description of option
  • Loading branch information
trengrj authored Dec 17, 2024
1 parent 2489a6b commit e4eced3
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 85 deletions.
10 changes: 4 additions & 6 deletions benchmarker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ Usage:
benchmarker [command]
Available Commands:
ann-benchmark Benchmark ANN Benchmark style hdf5 files (this is generally what you want to use)
dataset Benchmark vectors from an existing dataset
ann-benchmark Benchmark ANN Benchmark style datasets
help Help about any command
random-text Benchmark nearText searches
random-vectors Benchmark nearVector searches
random-vectors Benchmark random vector queries
raw Benchmark raw GraphQL queries
Flags:
Expand Down Expand Up @@ -129,7 +127,7 @@ go run . \
--dimensions 384 \
--queries 10000 \
--parallel 8 \
--api graphql \
--api grpc \
--limit 10
```

Expand Down Expand Up @@ -158,7 +156,7 @@ benchmarker \
--dimensions 384 \
--queries 10000 \
--parallel 8 \
--api graphql \
--api grpc \
--limit 10
```

2 changes: 1 addition & 1 deletion benchmarker/cmd/ann_benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ func runQueries(cfg *Config, importTime time.Duration, testData [][]float32, nei

var annBenchmarkCommand = &cobra.Command{
Use: "ann-benchmark",
Short: "Benchmark ANN Benchmark style hdf5 files",
Short: "Benchmark ANN Benchmark style datasets",
Long: `Run a gRPC benchmark on an hdf5 file in the format of ann-benchmarks.com`,
Run: func(cmd *cobra.Command, args []string) {

Expand Down
5 changes: 3 additions & 2 deletions benchmarker/cmd/benchmark_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ func processQueueHttp(queue []QueryWithNeighbors, cfg *Config, c *http.Client, m
r := bytes.NewReader(query.Query)
before := time.Now()
var url string
origin := fmt.Sprintf("%s://%s", cfg.HttpScheme, cfg.HttpOrigin)
if cfg.API == "graphql" {
url = cfg.Origin + "/v1/graphql"
url = origin + "/v1/graphql"
} else if cfg.API == "rest" {
url = fmt.Sprintf("%s/v1/objects/%s/_search", cfg.Origin, cfg.ClassName)
url = fmt.Sprintf("%s/v1/objects/%s/_search", origin, cfg.ClassName)
}
req, err := http.NewRequest("POST", url, r)
if err != nil {
Expand Down
6 changes: 1 addition & 5 deletions benchmarker/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *Config) validateCommon() error {
}

switch c.API {
case "graphql", "nearvector", "grpc":
case "graphql", "grpc":
default:
return errors.Errorf("unsupported API %q", c.API)
}
Expand Down Expand Up @@ -128,10 +128,6 @@ func (c Config) validateRandomText() error {
}

func (c Config) validateRandomVectors() error {
if c.Dimensions == 0 {
return errors.Errorf("dimensions must be set and larger than 0\n")
}

return nil
}

Expand Down
15 changes: 5 additions & 10 deletions benchmarker/cmd/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (
)

var datasetCmd = &cobra.Command{
Use: "dataset",
Short: "Benchmark vectors from an existing dataset",
Long: `Specify an existing dataset as a list of query vectors in a .json file to parse the query vectors and then query them with the specified parallelism`,
Use: "dataset",
Short: "Benchmark vectors from an existing dataset",
Long: `Specify an existing dataset as a list of query vectors in a .json file to parse the query vectors and then query them with the specified parallelism`,
Deprecated: "This command is deprecated and will be removed in the future",
Run: func(cmd *cobra.Command, args []string) {
cfg := globalConfig
cfg.Mode = "dataset"
Expand Down Expand Up @@ -112,15 +113,9 @@ func benchmarkDataset(cfg Config, queries Queries) Results {
}
}

if cfg.API == "rest" {
return QueryWithNeighbors{
Query: nearVectorQueryJSONRest(cfg.ClassName, queries[i], cfg.Limit),
}
}

if cfg.API == "grpc" {
return QueryWithNeighbors{
Query: nearVectorQueryGrpc(&cfg, queries[i], cfg.Tenant, 0),
Query: nearVectorQueryGrpc(&cfg, queries[i], cfg.Tenant, -1),
}
}

Expand Down
1 change: 1 addition & 0 deletions benchmarker/cmd/random_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var randomTextCmd = &cobra.Command{
Use: "random-text",
Short: "Benchmark nearText searches",
Long: `Benchmark random nearText searches`,
Deprecated: "This command is deprecated and will be removed in the future",
Run: func(cmd *cobra.Command, args []string) {
cfg := globalConfig

Expand Down
154 changes: 93 additions & 61 deletions benchmarker/cmd/random_vectors.go
Original file line number Diff line number Diff line change
@@ -1,52 +1,55 @@
package cmd

import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"math/rand"
"os"
"runtime"
"strconv"
"strings"
"time"

log "github.com/sirupsen/logrus"

"github.com/spf13/cobra"
"github.com/weaviate/weaviate-go-client/v4/weaviate"
weaviategrpc "github.com/weaviate/weaviate/grpc/generated/protocol/v1"

"google.golang.org/protobuf/proto"
)

func initRandomVectors() {
rootCmd.AddCommand(randomVectorsCmd)
numCPU := runtime.NumCPU()
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Queries,
"queries", "q", 100, "Set the number of queries the benchmarker should run")
randomVectorsCmd.PersistentFlags().IntVar(&globalConfig.QueryDuration,
"queryDuration", 0, "Instead of a fixed number of queries, query for the specified duration in seconds (default 0)")
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Parallel,
"parallel", "p", 8, "Set the number of parallel threads which send queries")
"parallel", "p", numCPU, "Set the number of parallel threads which send queries")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.API,
"api", "a", "grpc", "API (graphql | grpc) default and recommended is grpc")
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Limit,
"limit", "l", 10, "Set the query limit (top_k)")
randomVectorsCmd.PersistentFlags().IntVarP(&globalConfig.Dimensions,
"dimensions", "d", 768, "Set the vector dimensions (must match your data)")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.WhereFilter,
"where", "w", "", "An entire where filter as a string")
"dimensions", "d", 0, "Set the vector dimensions (will infer from class if not set)")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.ClassName,
"className", "c", "", "The Weaviate class to run the benchmark against")
randomVectorsCmd.PersistentFlags().StringVar(&globalConfig.DB,
"db", "weaviate", "The tool you're benchmarking")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.API,
"api", "a", "graphql", "The API to use on benchmarks")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.Origin,
"origin", "u", "http://localhost:8080", "The origin that Weaviate is running at")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.OutputFormat,
"format", "f", "text", "Output format, one of [text, json]")
randomVectorsCmd.PersistentFlags().StringVarP(&globalConfig.OutputFile,
"output", "o", "", "Filename for an output file. If none provided, output to stdout only")
"grpcOrigin", "u", "localhost:50051", "The gRPC origin that Weaviate is running at")
randomVectorsCmd.PersistentFlags().StringVar(&globalConfig.HttpOrigin,
"httpOrigin", "localhost:8080", "The http origin for Weaviate (without http scheme)")
randomVectorsCmd.PersistentFlags().StringVar(&globalConfig.HttpScheme,
"httpScheme", "http", "The http scheme (http or https)")
}

var randomVectorsCmd = &cobra.Command{
Use: "random-vectors",
Short: "Benchmark nearVector searches",
Long: `Benchmark random nearVector searches`,
Short: "Benchmark random vector queries",
Long: `Benchmark random vector queries`,
Run: func(cmd *cobra.Command, args []string) {
cfg := globalConfig
cfg.Mode = "random-vectors"
Expand All @@ -56,43 +59,49 @@ var randomVectorsCmd = &cobra.Command{
os.Exit(1)
}

if len(cfg.WhereFilter) > 0 {
filter := fmt.Sprintf(", where: { %s }", cfg.WhereFilter)
cfg.WhereFilter = strings.Replace(filter, "\"", "\\\"", -1)
}
log.WithFields(log.Fields{"queries": cfg.Queries,
"class": cfg.ClassName}).Info("Beginning random-vectors benchmark")

if cfg.DB == "weaviate" {
client := createClient(&cfg)
cfg.Dimensions = getDimensions(cfg, client)

var w io.Writer
if cfg.OutputFile == "" {
w = os.Stdout
} else {
f, err := os.Create(cfg.OutputFile)
if err != nil {
fatal(err)
}
var result Results

defer f.Close()
w = f
if cfg.QueryDuration > 0 {
result = benchmarkNearVectorDuration(cfg)
} else {
result = benchmarkNearVector(cfg)
}

}
log.WithFields(log.Fields{"mean": result.Mean, "qps": result.QueriesPerSecond,
"parallel": cfg.Parallel, "limit": cfg.Limit,
"api": cfg.API, "count": result.Total, "failed": result.Failed}).Info("Benchmark result")

result := benchmarkNearVector(cfg)
if cfg.OutputFormat == "json" {
result.WriteJSONTo(w)
} else if cfg.OutputFormat == "text" {
result.WriteTextTo(w)
}
},
}

if cfg.OutputFile != "" {
infof("results succesfully written to %q", cfg.OutputFile)
func getDimensions(cfg Config, client *weaviate.Client) int {
dimensions := cfg.Dimensions
if cfg.Dimensions == 0 {
// Try to infer dimensions from class

objects, err := client.Data().ObjectsGetter().WithClassName(cfg.ClassName).WithVector().WithLimit(10).Do(context.Background())
if err != nil {
log.Infof("Error fetching class %s, %v", cfg.ClassName, err)
}

for _, obj := range objects {
if obj.Vector != nil {
dimensions = len(obj.Vector)
break
}
return
}

fmt.Printf("unrecognized db\n")
os.Exit(1)
},
if dimensions == 0 {
log.Fatalf("Could not fetch dimensions from class %s", cfg.ClassName)
}
}
return dimensions
}

func randomVector(dims int) []float32 {
Expand All @@ -118,14 +127,6 @@ func nearVectorQueryJSONGraphQL(className string, vec []float32, limit int, wher
}`, className, limit, string(vecJSON), whereFilter))
}

func nearVectorQueryJSONRest(className string, vec []float32, limit int) []byte {
vecJSON, _ := json.Marshal(vec)
return []byte(fmt.Sprintf(`{
"nearVector":{"vector":%s},
"limit":%d
}`, string(vecJSON), limit))
}

func encodeVector(fs []float32) []byte {
buf := make([]byte, len(fs)*4)
for i, f := range fs {
Expand Down Expand Up @@ -190,19 +191,50 @@ func benchmarkNearVector(cfg Config) Results {
Query: nearVectorQueryJSONGraphQL(cfg.ClassName, randomVector(cfg.Dimensions), cfg.Limit, cfg.WhereFilter),
}
}

if cfg.API == "rest" {
return QueryWithNeighbors{
Query: nearVectorQueryJSONRest(cfg.ClassName, randomVector(cfg.Dimensions), cfg.Limit),
}
}

if cfg.API == "grpc" {
return QueryWithNeighbors{
Query: nearVectorQueryGrpc(&cfg, randomVector(cfg.Dimensions), cfg.Tenant, 0),
Query: nearVectorQueryGrpc(&cfg, randomVector(cfg.Dimensions), cfg.Tenant, -1),
}
}

return QueryWithNeighbors{}
})
}

func benchmarkNearVectorDuration(cfg Config) Results {

var samples sampledResults

startTime := time.Now()

var results Results
iterations := 0
for time.Since(startTime) < time.Duration(cfg.QueryDuration)*time.Second {
results = benchmarkNearVector(cfg)
samples.Min = append(samples.Min, results.Min)
samples.Max = append(samples.Max, results.Max)
samples.Mean = append(samples.Mean, results.Mean)
samples.Took = append(samples.Took, results.Took)
samples.QueriesPerSecond = append(samples.QueriesPerSecond, results.QueriesPerSecond)
samples.Results = append(samples.Results, results)
iterations += 1
}

var medianResult Results

medianResult.Min = time.Duration(median(samples.Min))
medianResult.Max = time.Duration(median(samples.Max))
medianResult.Mean = time.Duration(median(samples.Mean))
medianResult.Took = time.Duration(median(samples.Took))
medianResult.QueriesPerSecond = median(samples.QueriesPerSecond)
medianResult.Percentiles = results.Percentiles
medianResult.PercentilesLabels = results.PercentilesLabels
medianResult.Total = results.Total
medianResult.Successful = results.Successful
medianResult.Failed = results.Failed
medianResult.Parallelization = cfg.Parallel

log.WithFields(log.Fields{"iterations": iterations}).Infof("Queried for %d seconds", cfg.QueryDuration)

return medianResult
}

0 comments on commit e4eced3

Please sign in to comment.