diff --git a/.golangci.yml b/.golangci.yml index 9099fa8..f4ce6fe 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,6 +25,7 @@ linters-settings: - github.com/aerospike/tools-common-go - github.com/aerospike/avs-client-go - asvec/cmd + - asvec/utils - github.com/spf13/cobra - github.com/spf13/viper - github.com/spf13/pflag diff --git a/cmd/flags/constants.go b/cmd/flags/constants.go index a590452..e878176 100644 --- a/cmd/flags/constants.go +++ b/cmd/flags/constants.go @@ -80,6 +80,8 @@ const ( DefaultIPv4 = "127.0.0.1" DefaultPort = 5000 + + Infinity = -1 ) func AddFormatTestFlag(flagSet *pflag.FlagSet, val *int) error { diff --git a/cmd/flags/credentials_test.go b/cmd/flags/credentials_test.go index d3ca3b5..6672b39 100644 --- a/cmd/flags/credentials_test.go +++ b/cmd/flags/credentials_test.go @@ -3,7 +3,7 @@ package flags import ( - "asvec/tests" + "asvec/utils" "testing" ) @@ -47,8 +47,8 @@ func TestCredentialsFlag_Type(t *testing.T) { func TestCredentialsFlag_String(t *testing.T) { // Test string representation with user and password flag := CredentialsFlag{ - User: StringOptionalFlag{Val: tests.Ptr("username")}, - Password: StringOptionalFlag{Val: tests.Ptr("password")}, + User: StringOptionalFlag{Val: utils.Ptr("username")}, + Password: StringOptionalFlag{Val: utils.Ptr("password")}, } str := flag.String() expected := "username:password" @@ -58,7 +58,7 @@ func TestCredentialsFlag_String(t *testing.T) { // Test string representation with user only flag = CredentialsFlag{ - User: StringOptionalFlag{Val: tests.Ptr("username")}, + User: StringOptionalFlag{Val: utils.Ptr("username")}, Password: StringOptionalFlag{}, } str = flag.String() diff --git a/cmd/flags/hnsw.go b/cmd/flags/hnsw.go index 1ac3108..c5c908c 100644 --- a/cmd/flags/hnsw.go +++ b/cmd/flags/hnsw.go @@ -43,13 +43,13 @@ func (cf *BatchingFlags) NewSLogAttr() []any { type CachingFlags struct { MaxEntries Uint64OptionalFlag - Expiry DurationOptionalFlag + Expiry InfDurationOptionalFlag } func NewHnswCachingFlags() *CachingFlags { return &CachingFlags{ MaxEntries: Uint64OptionalFlag{}, - Expiry: DurationOptionalFlag{}, + Expiry: InfDurationOptionalFlag{}, } } @@ -57,7 +57,7 @@ func NewHnswCachingFlags() *CachingFlags { func (cf *CachingFlags) NewFlagSet() *pflag.FlagSet { flagSet := &pflag.FlagSet{} flagSet.Var(&cf.MaxEntries, HnswCacheMaxEntries, "Maximum number of entries to cache.") - flagSet.Var(&cf.Expiry, HnswCacheExpiry, "A cache entry will expire after this amount of time has passed since the entry was added to cache") + flagSet.Var(&cf.Expiry, HnswCacheExpiry, "A cache entry will expire after this amount of time has passed since the entry was added to cache, or 'inf' to never expire.") return flagSet } @@ -65,7 +65,7 @@ func (cf *CachingFlags) NewFlagSet() *pflag.FlagSet { func (cf *CachingFlags) NewSLogAttr() []any { return []any{ slog.Any(HnswCacheMaxEntries, cf.MaxEntries.Val), - slog.Any(HnswCacheExpiry, cf.Expiry.Val), + slog.String(HnswCacheExpiry, cf.Expiry.String()), } } diff --git a/cmd/flags/optionals.go b/cmd/flags/optionals.go index 5b2deec..8fbe1f6 100644 --- a/cmd/flags/optionals.go +++ b/cmd/flags/optionals.go @@ -1,8 +1,10 @@ package flags import ( + "asvec/utils" "fmt" "strconv" + "strings" "time" ) @@ -206,3 +208,53 @@ func (f *DurationOptionalFlag) Int64() *int64 { return &milli } + +// InfDurationOptionalFlag is a flag that can be either a time.duration or infinity. +// It is used for flags like --hnsw-cache-expiry which can be set to "infinity" +type InfDurationOptionalFlag struct { + duration DurationOptionalFlag + isInfinite bool +} + +func (f *InfDurationOptionalFlag) Set(val string) error { + err := f.duration.Set(val) + if err == nil { + return nil + } + + val = strings.ToLower(val) + + if val == "inf" || val == "infinity" || val == "-1" { + f.isInfinite = true + } else { + return fmt.Errorf("invalid duration %s", val) + } + + return nil +} + +func (f *InfDurationOptionalFlag) Type() string { + return "time.Duration" +} + +func (f *InfDurationOptionalFlag) String() string { + if f.isInfinite { + return "infinity" + } + + if f.duration.Val != nil { + return f.duration.String() + } + + return optionalEmptyString +} + +// Uint64 returns the duration as a uint64. If the duration is infinite, it returns -1. +// The AVS server uses -1 for cache expiry to represent infinity or never expire. +func (f *InfDurationOptionalFlag) Int64() *int64 { + if f.isInfinite { + return utils.Ptr(int64(Infinity)) + } + + return f.duration.Int64() +} diff --git a/cmd/flags/optionals_test.go b/cmd/flags/optionals_test.go index 7947060..08c9f1c 100644 --- a/cmd/flags/optionals_test.go +++ b/cmd/flags/optionals_test.go @@ -110,3 +110,43 @@ func (suite *OptionalFlagSuite) TestDurationOptionalFlag() { suite.T().Errorf("Expected error, got nil") } } + +func (suite *OptionalFlagSuite) TestInfDurationOptionalFlag() { + f := &InfDurationOptionalFlag{} + + err := f.Set("inf") + if err != nil { + suite.T().Errorf("Unexpected error: %v", err) + } + + suite.Equal("infinity", f.String()) + suite.Equal(int64(-1), *f.Int64()) + f = &InfDurationOptionalFlag{} + + err = f.Set("infinity") + if err != nil { + suite.T().Errorf("Unexpected error: %v", err) + } + + suite.Equal("infinity", f.String()) + suite.Equal(int64(-1), *f.Int64()) + f = &InfDurationOptionalFlag{} + + err = f.Set("-1") + if err != nil { + suite.T().Errorf("Unexpected error: %v", err) + } + + suite.Equal("infinity", f.String()) + suite.Equal(int64(-1), *f.Int64()) + f = &InfDurationOptionalFlag{} + + err = f.Set("20m") + if err != nil { + suite.T().Errorf("Unexpected error: %v", err) + } + + expectedDuration := time.Duration(20) * time.Minute + suite.Equal(expectedDuration.String(), f.String()) + suite.Equal(expectedDuration.Milliseconds(), *f.Int64()) +} diff --git a/tests/utils.go b/tests/utils.go index 44900f9..a08b7d9 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -3,6 +3,7 @@ package tests import ( + "asvec/utils" "context" "crypto/tls" "crypto/x509" @@ -18,10 +19,6 @@ import ( "github.com/aerospike/tools-common-go/client" ) -func Ptr[T any](v T) *T { - return &v -} - func CreateFlagStr(name, value string) string { return fmt.Sprintf("--%s %s", name, value) } @@ -195,7 +192,7 @@ func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition { Namespace: idb.namespace, }, Dimensions: uint32(idb.dimension), - VectorDistanceMetric: Ptr(idb.vectorDistanceMetric), + VectorDistanceMetric: utils.Ptr(idb.vectorDistanceMetric), Field: idb.vectorField, // Storage: , Params: &protos.IndexDefinition_HnswParams{ @@ -214,7 +211,7 @@ func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition { Namespace: idb.namespace, }, Dimensions: uint32(idb.dimension), - VectorDistanceMetric: Ptr(idb.vectorDistanceMetric), + VectorDistanceMetric: utils.Ptr(idb.vectorDistanceMetric), Field: idb.vectorField, Storage: &protos.IndexStorage{}, Params: &protos.IndexDefinition_HnswParams{ diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..947538c --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,5 @@ +package utils + +func Ptr[T any](v T) *T { + return &v +}