Skip to content

Commit

Permalink
feat: vec-380 allow hnsw-cache-expiry to be -1, inf, or infinity, for…
Browse files Browse the repository at this point in the history
… never expire (#18)

* feat: allow hnsw-cache-expiry to be -1, inf, or infinity, to signify never expire

* lint: golangci-lint allow asvec/utils import
  • Loading branch information
dwelch-spike authored Oct 25, 2024
1 parent 74f4ea3 commit be1ca23
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 14 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ linters-settings:
- github.com/aerospike/tools-common-go
- github.com/aerospike/avs-client-go
- asvec/cmd
- asvec/utils
- github.com/spf13/cobra
- github.com/spf13/viper
- github.com/spf13/pflag
Expand Down
2 changes: 2 additions & 0 deletions cmd/flags/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ const (

DefaultIPv4 = "127.0.0.1"
DefaultPort = 5000

Infinity = -1
)

func AddFormatTestFlag(flagSet *pflag.FlagSet, val *int) error {
Expand Down
8 changes: 4 additions & 4 deletions cmd/flags/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
package flags

import (
"asvec/tests"
"asvec/utils"
"testing"
)

Expand Down Expand Up @@ -47,8 +47,8 @@ func TestCredentialsFlag_Type(t *testing.T) {
func TestCredentialsFlag_String(t *testing.T) {
// Test string representation with user and password
flag := CredentialsFlag{
User: StringOptionalFlag{Val: tests.Ptr("username")},
Password: StringOptionalFlag{Val: tests.Ptr("password")},
User: StringOptionalFlag{Val: utils.Ptr("username")},
Password: StringOptionalFlag{Val: utils.Ptr("password")},
}
str := flag.String()
expected := "username:password"
Expand All @@ -58,7 +58,7 @@ func TestCredentialsFlag_String(t *testing.T) {

// Test string representation with user only
flag = CredentialsFlag{
User: StringOptionalFlag{Val: tests.Ptr("username")},
User: StringOptionalFlag{Val: utils.Ptr("username")},
Password: StringOptionalFlag{},
}
str = flag.String()
Expand Down
8 changes: 4 additions & 4 deletions cmd/flags/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,29 @@ func (cf *BatchingFlags) NewSLogAttr() []any {

type CachingFlags struct {
MaxEntries Uint64OptionalFlag
Expiry DurationOptionalFlag
Expiry InfDurationOptionalFlag
}

func NewHnswCachingFlags() *CachingFlags {
return &CachingFlags{
MaxEntries: Uint64OptionalFlag{},
Expiry: DurationOptionalFlag{},
Expiry: InfDurationOptionalFlag{},
}
}

//nolint:lll // For readability
func (cf *CachingFlags) NewFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.Var(&cf.MaxEntries, HnswCacheMaxEntries, "Maximum number of entries to cache.")
flagSet.Var(&cf.Expiry, HnswCacheExpiry, "A cache entry will expire after this amount of time has passed since the entry was added to cache")
flagSet.Var(&cf.Expiry, HnswCacheExpiry, "A cache entry will expire after this amount of time has passed since the entry was added to cache, or 'inf' to never expire.")

return flagSet
}

func (cf *CachingFlags) NewSLogAttr() []any {
return []any{
slog.Any(HnswCacheMaxEntries, cf.MaxEntries.Val),
slog.Any(HnswCacheExpiry, cf.Expiry.Val),
slog.String(HnswCacheExpiry, cf.Expiry.String()),
}
}

Expand Down
52 changes: 52 additions & 0 deletions cmd/flags/optionals.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package flags

import (
"asvec/utils"
"fmt"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -206,3 +208,53 @@ func (f *DurationOptionalFlag) Int64() *int64 {

return &milli
}

// InfDurationOptionalFlag is a flag that can be either a time.duration or infinity.
// It is used for flags like --hnsw-cache-expiry which can be set to "infinity"
type InfDurationOptionalFlag struct {
duration DurationOptionalFlag
isInfinite bool
}

func (f *InfDurationOptionalFlag) Set(val string) error {
err := f.duration.Set(val)
if err == nil {
return nil
}

val = strings.ToLower(val)

if val == "inf" || val == "infinity" || val == "-1" {
f.isInfinite = true
} else {
return fmt.Errorf("invalid duration %s", val)
}

return nil
}

func (f *InfDurationOptionalFlag) Type() string {
return "time.Duration"
}

func (f *InfDurationOptionalFlag) String() string {
if f.isInfinite {
return "infinity"
}

if f.duration.Val != nil {
return f.duration.String()
}

return optionalEmptyString
}

// Uint64 returns the duration as a uint64. If the duration is infinite, it returns -1.
// The AVS server uses -1 for cache expiry to represent infinity or never expire.
func (f *InfDurationOptionalFlag) Int64() *int64 {
if f.isInfinite {
return utils.Ptr(int64(Infinity))
}

return f.duration.Int64()
}
40 changes: 40 additions & 0 deletions cmd/flags/optionals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,43 @@ func (suite *OptionalFlagSuite) TestDurationOptionalFlag() {
suite.T().Errorf("Expected error, got nil")
}
}

func (suite *OptionalFlagSuite) TestInfDurationOptionalFlag() {
f := &InfDurationOptionalFlag{}

err := f.Set("inf")
if err != nil {
suite.T().Errorf("Unexpected error: %v", err)
}

suite.Equal("infinity", f.String())
suite.Equal(int64(-1), *f.Int64())
f = &InfDurationOptionalFlag{}

err = f.Set("infinity")
if err != nil {
suite.T().Errorf("Unexpected error: %v", err)
}

suite.Equal("infinity", f.String())
suite.Equal(int64(-1), *f.Int64())
f = &InfDurationOptionalFlag{}

err = f.Set("-1")
if err != nil {
suite.T().Errorf("Unexpected error: %v", err)
}

suite.Equal("infinity", f.String())
suite.Equal(int64(-1), *f.Int64())
f = &InfDurationOptionalFlag{}

err = f.Set("20m")
if err != nil {
suite.T().Errorf("Unexpected error: %v", err)
}

expectedDuration := time.Duration(20) * time.Minute
suite.Equal(expectedDuration.String(), f.String())
suite.Equal(expectedDuration.Milliseconds(), *f.Int64())
}
9 changes: 3 additions & 6 deletions tests/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package tests

import (
"asvec/utils"
"context"
"crypto/tls"
"crypto/x509"
Expand All @@ -18,10 +19,6 @@ import (
"github.com/aerospike/tools-common-go/client"
)

func Ptr[T any](v T) *T {
return &v
}

func CreateFlagStr(name, value string) string {
return fmt.Sprintf("--%s %s", name, value)
}
Expand Down Expand Up @@ -195,7 +192,7 @@ func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition {
Namespace: idb.namespace,
},
Dimensions: uint32(idb.dimension),
VectorDistanceMetric: Ptr(idb.vectorDistanceMetric),
VectorDistanceMetric: utils.Ptr(idb.vectorDistanceMetric),
Field: idb.vectorField,
// Storage: ,
Params: &protos.IndexDefinition_HnswParams{
Expand All @@ -214,7 +211,7 @@ func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition {
Namespace: idb.namespace,
},
Dimensions: uint32(idb.dimension),
VectorDistanceMetric: Ptr(idb.vectorDistanceMetric),
VectorDistanceMetric: utils.Ptr(idb.vectorDistanceMetric),
Field: idb.vectorField,
Storage: &protos.IndexStorage{},
Params: &protos.IndexDefinition_HnswParams{
Expand Down
5 changes: 5 additions & 0 deletions utils/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package utils

func Ptr[T any](v T) *T {
return &v
}

0 comments on commit be1ca23

Please sign in to comment.