Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed Jul 10, 2024
1 parent 9e69a5e commit 7d69032
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 43 deletions.
73 changes: 41 additions & 32 deletions cmd/flags/constants.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,45 @@
package flags

// Flag-name constants. Each value is the long name under which an option is
// registered on a pflag.FlagSet; the same constants are reused as keys when
// the parsed values are logged.
const (
LogLevel = "log-level"
Seeds = "seeds"
Host = "host"
ListenerName = "listener-name"
AuthUser = "user"
AuthPassword = "password"
Name = "name"
NewPassword = "new-password"
Roles = "roles"
Namespace = "namespace"
Sets = "sets"
Yes = "yes"
IndexName = "index-name"
VectorField = "vector-field"
Dimension = "dimension"
DistanceMetric = "distance-metric"
IndexMeta = "index-meta"
Timeout = "timeout"
Verbose = "verbose"
StorageNamespace = "storage-namespace"
StorageSet = "storage-set"
MaxEdges = "hnsw-max-edges"
ConstructionEf = "hnsw-ef-construction"
Ef = "hnsw-ef"
BatchMaxRecords = "hnsw-batch-max-records"
BatchInterval = "hnsw-batch-interval"
TLSProtocols = "tls-protocols"
TLSCaFile = "tls-cafile"
TLSCaPath = "tls-capath"
TLSCertFile = "tls-certfile"
TLSKeyFile = "tls-keyfile"
TLSKeyFilePass = "tls-keyfile-password" //nolint:gosec // Not a credential
LogLevel = "log-level"
Seeds = "seeds"
Host = "host"
ListenerName = "listener-name"
AuthUser = "user"
AuthPassword = "password"
Name = "name"
NewPassword = "new-password"
Roles = "roles"
Namespace = "namespace"
Sets = "sets"
Yes = "yes"
IndexName = "index-name"
VectorField = "vector-field"
Dimension = "dimension"
DistanceMetric = "distance-metric"
IndexMeta = "index-meta"
Timeout = "timeout"
Verbose = "verbose"
StorageNamespace = "storage-namespace"
StorageSet = "storage-set"
MaxEdges = "hnsw-max-edges"
ConstructionEf = "hnsw-ef-construction"
Ef = "hnsw-ef"
HnswMaxMemQueueSize = "hnsw-max-mem-queue-size"
BatchMaxRecords = "hnsw-batch-max-records"
BatchInterval = "hnsw-batch-interval"
HnswCacheMaxEntries = "hnsw-cache-max-entries"
HnswCacheExpiry = "hnsw-cache-expiry"
HnswHealerMaxScanRatePerNode = "hnsw-healer-max-scan-rate-per-node"
HnswHealerMaxScanPageSize = "hnsw-healer-max-scan-page-size"
HnswHealerReindexPercent = "hnsw-healer-reindex-percent"
HnswHealerScheduleDelay = "hnsw-healer-schedule-delay"
HnswHealerParallelism = "hnsw-healer-parallelism"
HnswMergeParallelism = "hnsw-merge-parallelism"
TLSProtocols = "tls-protocols"
TLSCaFile = "tls-cafile"
TLSCaPath = "tls-capath"
TLSCertFile = "tls-certfile"
TLSKeyFile = "tls-keyfile"
TLSKeyFilePass = "tls-keyfile-password" //nolint:gosec // Not a credential
)
124 changes: 124 additions & 0 deletions cmd/flags/hnsw.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package flags

import (
"log/slog"

commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/pflag"
)

// BatchingFlags groups the optional HNSW batching options so they can be
// registered and logged together.
type BatchingFlags struct {
// MaxRecords is registered under the BatchMaxRecords flag name.
MaxRecords Uint32OptionalFlag
// Interval is registered under the BatchInterval flag name.
Interval Uint32OptionalFlag
}

// NewBatchingFlags allocates a BatchingFlags with every field at its zero
// value.
func NewBatchingFlags() *BatchingFlags {
	return new(BatchingFlags)
}

// NewFlagSet returns a new pflag.FlagSet containing the batching options.
// The set is given pointers to cf's fields, so parsing it populates cf.
func (cf *BatchingFlags) NewFlagSet() *pflag.FlagSet {
	const (
		maxRecordsHelp = "Maximum number of records to fit in a batch. The default value is 10000."                                  //nolint:lll // For readability
		intervalHelp   = "The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000." //nolint:lll // For readability
	)

	fs := new(pflag.FlagSet)
	fs.Var(&cf.MaxRecords, BatchMaxRecords, commonFlags.DefaultWrapHelpString(maxRecordsHelp))
	fs.Var(&cf.Interval, BatchInterval, commonFlags.DefaultWrapHelpString(intervalHelp))

	return fs
}

// NewSLogAttr returns the batching options as slog attributes keyed by their
// flag names, in a form suitable for spreading into a logger call.
//
// Each value is rendered through the flag's String method rather than passed
// as the raw Val pointer: a pointer handed to slog.Any is printed as a memory
// address by slog's text handler, which hides the actual setting.
func (cf *BatchingFlags) NewSLogAttr() []any {
	return []any{
		slog.String(BatchMaxRecords, cf.MaxRecords.String()),
		slog.String(BatchInterval, cf.Interval.String()),
	}
}

// CachingFlags groups the optional HNSW cache options so they can be
// registered and logged together.
// NOTE(review): unlike BatchingFlags, the fields here are unexported, so code
// outside this package cannot read the parsed values — confirm this is intended.
type CachingFlags struct {
// maxEntries is registered under the HnswCacheMaxEntries flag name.
maxEntries Uint64OptionalFlag
// expiry is registered under the HnswCacheExpiry flag name.
expiry Uint64OptionalFlag
}

// NewCachingFlags allocates a CachingFlags with every field at its zero
// value.
func NewCachingFlags() *CachingFlags {
	return new(CachingFlags)
}

// NewFlagSet returns a new pflag.FlagSet containing the cache options.
// The set is given pointers to cf's fields, so parsing it populates cf.
func (cf *CachingFlags) NewFlagSet() *pflag.FlagSet {
	// The help text for these flags is still a placeholder.
	const placeholderHelp = "TODO"

	fs := new(pflag.FlagSet)
	fs.Var(&cf.maxEntries, HnswCacheMaxEntries, commonFlags.DefaultWrapHelpString(placeholderHelp))
	fs.Var(&cf.expiry, HnswCacheExpiry, commonFlags.DefaultWrapHelpString(placeholderHelp))

	return fs
}

// NewSLogAttr returns the cache options as slog attributes keyed by their
// flag names, in a form suitable for spreading into a logger call.
//
// Each value is rendered through the flag's String method rather than passed
// as the raw Val pointer: a pointer handed to slog.Any is printed as a memory
// address by slog's text handler, which hides the actual setting.
func (cf *CachingFlags) NewSLogAttr() []any {
	return []any{
		slog.String(HnswCacheMaxEntries, cf.maxEntries.String()),
		slog.String(HnswCacheExpiry, cf.expiry.String()),
	}
}

// HealerFlags groups the optional HNSW healer options so they can be
// registered and logged together.
// NOTE(review): unlike BatchingFlags, the fields here are unexported, so code
// outside this package cannot read the parsed values — confirm this is intended.
type HealerFlags struct {
// maxScanRatePerNode is registered under HnswHealerMaxScanRatePerNode.
maxScanRatePerNode Uint32OptionalFlag
// maxScanPageSize is registered under HnswHealerMaxScanPageSize.
maxScanPageSize Uint32OptionalFlag
// reindexPercent is registered under HnswHealerReindexPercent.
reindexPercent Float32OptionalFlag
// scheduleDelay is registered under HnswHealerScheduleDelay.
scheduleDelay Uint64OptionalFlag
// parallelism is registered under HnswHealerParallelism.
parallelism Uint32OptionalFlag
}

// NewHealerFlags allocates a HealerFlags with every field at its zero value.
func NewHealerFlags() *HealerFlags {
	return new(HealerFlags)
}

// NewFlagSet returns a new pflag.FlagSet containing the healer options.
// The set is given pointers to cf's fields, so parsing it populates cf.
func (cf *HealerFlags) NewFlagSet() *pflag.FlagSet {
	// The help text for these flags is still a placeholder.
	const placeholderHelp = "TODO"

	// Registration order matches the field order of HealerFlags.
	bindings := []struct {
		target pflag.Value
		name   string
	}{
		{&cf.maxScanRatePerNode, HnswHealerMaxScanRatePerNode},
		{&cf.maxScanPageSize, HnswHealerMaxScanPageSize},
		{&cf.reindexPercent, HnswHealerReindexPercent},
		{&cf.scheduleDelay, HnswHealerScheduleDelay},
		{&cf.parallelism, HnswHealerParallelism},
	}

	fs := new(pflag.FlagSet)
	for _, b := range bindings {
		fs.Var(b.target, b.name, commonFlags.DefaultWrapHelpString(placeholderHelp))
	}

	return fs
}

// NewSLogAttr returns the healer options as slog attributes keyed by their
// flag names, in a form suitable for spreading into a logger call.
//
// Each value is rendered through the flag's String method rather than passed
// as the raw Val pointer: a pointer handed to slog.Any is printed as a memory
// address by slog's text handler, which hides the actual setting.
func (cf *HealerFlags) NewSLogAttr() []any {
	return []any{
		slog.String(HnswHealerMaxScanRatePerNode, cf.maxScanRatePerNode.String()),
		slog.String(HnswHealerMaxScanPageSize, cf.maxScanPageSize.String()),
		slog.String(HnswHealerReindexPercent, cf.reindexPercent.String()),
		slog.String(HnswHealerScheduleDelay, cf.scheduleDelay.String()),
		slog.String(HnswHealerParallelism, cf.parallelism.String()),
	}
}

// MergeFlags holds the optional HNSW merge option so it can be registered and
// logged like the other HNSW flag groups.
// NOTE(review): unlike BatchingFlags, the field here is unexported, so code
// outside this package cannot read the parsed value — confirm this is intended.
type MergeFlags struct {
// parallelism is registered under the HnswMergeParallelism flag name.
parallelism Uint32OptionalFlag
}

// NewMergeFlags allocates a MergeFlags with its field at the zero value.
func NewMergeFlags() *MergeFlags {
	return new(MergeFlags)
}

// NewFlagSet returns a new pflag.FlagSet containing the merge option.
// The set is given a pointer to cf's field, so parsing it populates cf.
func (cf *MergeFlags) NewFlagSet() *pflag.FlagSet {
	// The help text for this flag is still a placeholder.
	const placeholderHelp = "TODO"

	fs := new(pflag.FlagSet)
	fs.Var(&cf.parallelism, HnswMergeParallelism, commonFlags.DefaultWrapHelpString(placeholderHelp))

	return fs
}

// NewSLogAttr returns the merge option as a slog attribute keyed by its flag
// name, in a form suitable for spreading into a logger call.
//
// The value is rendered through the flag's String method rather than passed
// as the raw Val pointer: a pointer handed to slog.Any is printed as a memory
// address by slog's text handler, which hides the actual setting.
func (cf *MergeFlags) NewSLogAttr() []any {
	return []any{
		slog.String(HnswMergeParallelism, cf.parallelism.String()),
	}
}
47 changes: 47 additions & 0 deletions cmd/flags/optionals.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,53 @@ func (f *Uint32OptionalFlag) String() string {
return optionalEmptyString
}

// Uint64OptionalFlag is a flag value (Set/Type/String) wrapping an optional
// uint64. Val is nil until Set successfully stores a value.
type Uint64OptionalFlag struct {
	Val *uint64
}

// Set parses val as an unsigned 64-bit integer and stores it in Val. The base
// is taken from the literal's prefix because ParseUint is called with base 0.
//
// On a parse error the error is returned and Val is left unchanged, so a
// rejected input never leaves a bogus non-nil value behind.
func (f *Uint64OptionalFlag) Set(val string) error {
	v, err := strconv.ParseUint(val, 0, 64)
	if err != nil {
		return err
	}

	f.Val = &v

	return nil
}

// Type returns the type name reported for this flag.
func (f *Uint64OptionalFlag) Type() string {
	return "uint64"
}

// String returns the stored value in base 10, or optionalEmptyString when no
// value has been set.
func (f *Uint64OptionalFlag) String() string {
	if f.Val == nil {
		return optionalEmptyString
	}

	return strconv.FormatUint(*f.Val, 10)
}

// Float32OptionalFlag is a flag value (Set/Type/String) wrapping an optional
// float32. Val is nil until Set successfully stores a value.
type Float32OptionalFlag struct {
	Val *float32
}

// Set parses val as a 32-bit floating-point number and stores it in Val.
//
// On a parse error the error is returned and Val is left unchanged, so a
// rejected input never leaves a bogus non-nil value behind.
func (f *Float32OptionalFlag) Set(val string) error {
	v, err := strconv.ParseFloat(val, 32)
	if err != nil {
		return err
	}

	f32Val := float32(v)
	f.Val = &f32Val

	return nil
}

// Type returns the type name reported for this flag.
func (f *Float32OptionalFlag) Type() string {
	return "float32"
}

// String returns the stored value formatted as a float32, or
// optionalEmptyString when no value has been set.
//
// The value is formatted with FormatFloat at 32-bit precision. Converting it
// to uint64 first would drop the fractional part (0.5 would print as "0") and
// is not meaningful for negative values.
func (f *Float32OptionalFlag) String() string {
	if f.Val == nil {
		return optionalEmptyString
	}

	return strconv.FormatFloat(float64(*f.Val), 'g', -1, 32)
}

type BoolOptionalFlag struct {
Val *bool
}
Expand Down
32 changes: 21 additions & 11 deletions cmd/indexCreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,23 @@ var indexCreateFlags = &struct {
hnswMaxEdges flags.Uint32OptionalFlag
hnswEf flags.Uint32OptionalFlag
hnswConstructionEf flags.Uint32OptionalFlag
hnswBatchMaxRecords flags.Uint32OptionalFlag
hnswBatchInterval flags.Uint32OptionalFlag
hnswMaxMemQueueSize flags.Uint32OptionalFlag
hnswBatch flags.BatchingFlags
hnswCache flags.CachingFlags
hnswHealer flags.HealerFlags
hnswMerge flags.MergeFlags
}{
clientFlags: *flags.NewClientFlags(),
storageNamespace: flags.StringOptionalFlag{},
storageSet: flags.StringOptionalFlag{},
hnswMaxEdges: flags.Uint32OptionalFlag{},
hnswEf: flags.Uint32OptionalFlag{},
hnswConstructionEf: flags.Uint32OptionalFlag{},
hnswBatchMaxRecords: flags.Uint32OptionalFlag{},
hnswBatchInterval: flags.Uint32OptionalFlag{},
hnswMaxMemQueueSize: flags.Uint32OptionalFlag{},
hnswBatch: *flags.NewBatchingFlags(),
hnswCache: *flags.NewCachingFlags(),
hnswHealer: *flags.NewHealerFlags(),
hnswMerge: *flags.NewMergeFlags(),
}

func newIndexCreateFlagSet() *pflag.FlagSet {
Expand All @@ -58,10 +64,12 @@ func newIndexCreateFlagSet() *pflag.FlagSet {
flagSet.Var(&indexCreateFlags.hnswMaxEdges, flags.MaxEdges, commonFlags.DefaultWrapHelpString("Maximum number bi-directional links per HNSW vertex. Greater values of 'm' in general provide better recall for data with high dimensionality, while lower values work well for data with lower dimensionality. The storage space required for the index increases proportionally with 'm'.")) //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.hnswConstructionEf, flags.ConstructionEf, commonFlags.DefaultWrapHelpString("The number of candidate nearest neighbors shortlisted during index creation. Larger values provide better recall at the cost of longer index update times. The default is 100.")) //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.hnswEf, flags.Ef, commonFlags.DefaultWrapHelpString("The default number of candidate nearest neighbors shortlisted during search. Larger values provide better recall at the cost of longer search times. The default is 100.")) //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.hnswBatchMaxRecords, flags.BatchMaxRecords, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000.")) //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.hnswBatchInterval, flags.BatchInterval, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000.")) //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.hnswMaxMemQueueSize, flags.HnswMaxMemQueueSize, commonFlags.DefaultWrapHelpString("TODO")) //nolint:lll // For readability //nolint:lll // For readability
flagSet.AddFlagSet(indexCreateFlags.clientFlags.NewClientFlagSet())

flagSet.AddFlagSet(indexCreateFlags.hnswBatch.NewFlagSet())
flagSet.AddFlagSet(indexCreateFlags.hnswCache.NewFlagSet())
flagSet.AddFlagSet(indexCreateFlags.hnswHealer.NewFlagSet())
flagSet.AddFlagSet(indexCreateFlags.hnswMerge.NewFlagSet())
return flagSet
}

Expand Down Expand Up @@ -109,8 +117,10 @@ asvec index create -i myindex -n test -s testset -d 256 -m COSINE --%s vector \
slog.Any(flags.MaxEdges, indexCreateFlags.hnswMaxEdges.String()),
slog.Any(flags.Ef, indexCreateFlags.hnswEf),
slog.Any(flags.ConstructionEf, indexCreateFlags.hnswConstructionEf.String()),
slog.Any(flags.BatchMaxRecords, indexCreateFlags.hnswBatchMaxRecords.String()),
slog.Any(flags.BatchInterval, indexCreateFlags.hnswBatchInterval.String()),
indexCreateFlags.hnswBatch.NewSLogAttr(),
// slog.Any(flags.BatchMaxRecords, indexCreateFlags.hnswBatchMaxRecords.String()),
// slog.Any(flags.BatchInterval,
// indexCreateFlags.hnswBatchInterval.String()), TODO
)...,
)

Expand All @@ -132,8 +142,8 @@ asvec index create -i myindex -n test -s testset -d 256 -m COSINE --%s vector \
Ef: indexCreateFlags.hnswEf.Val,
EfConstruction: indexCreateFlags.hnswConstructionEf.Val,
BatchingParams: &protos.HnswBatchingParams{
MaxRecords: indexCreateFlags.hnswBatchMaxRecords.Val,
Interval: indexCreateFlags.hnswBatchInterval.Val,
MaxRecords: indexCreateFlags.hnswBatch.MaxRecords.Val,
Interval: indexCreateFlags.hnswBatch.Interval.Val,
},
},
}
Expand Down

0 comments on commit 7d69032

Please sign in to comment.