Skip to content

Commit

Permalink
fix tests, KNN/rowids in
Browse files Browse the repository at this point in the history
  • Loading branch information
asg017 committed Nov 13, 2024
1 parent 4cb891a commit 8369dfe
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
36 changes: 29 additions & 7 deletions sqlite-vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -4189,6 +4189,8 @@ void vec0_query_point_data_clear(struct vec0_query_point_data *point_data) {
}

typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!

VEC0_QUERY_PLAN_FULLSCAN = '1',
VEC0_QUERY_PLAN_POINT = '2',
VEC0_QUERY_PLAN_KNN = '3',
Expand Down Expand Up @@ -4649,6 +4651,8 @@ static int vec0Close(sqlite3_vtab_cursor *cur) {
// All the different type of "values" provided to argv/argc in vec0Filter.
// These enums denote the use and purpose of all of them.
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!

VEC0_IDXSTR_KIND_KNN_MATCH = '{',
VEC0_IDXSTR_KIND_KNN_K = '}',
VEC0_IDXSTR_KIND_KNN_ROWID_IN = '[',
Expand All @@ -4659,6 +4663,8 @@ typedef enum {
// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
// support, but as characters that fit nicely in idxstr.
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!

VEC0_PARTITION_OPERATOR_EQ = 'a',
VEC0_PARTITION_OPERATOR_GT = 'b',
VEC0_PARTITION_OPERATOR_LE = 'c',
Expand Down Expand Up @@ -5428,8 +5434,7 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,

int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
const char *idxStr, int argc, sqlite3_value **argv) {
UNUSED_PARAMETER(idxStr);
assert(argc >= 2);
assert(argc == (strlen(idxStr)-1) / 4);
int rc;
struct vec0_query_knn_data *knn_data;

Expand All @@ -5450,9 +5455,26 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
}
memset(knn_data, 0, sizeof(*knn_data));

int query_idx =-1;
int k_idx = -1;
int rowid_in_idx = -1;
for(int i = 0; i < argc; i++) {
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
query_idx = i;
}
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_K) {
k_idx = i;
}
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) {
rowid_in_idx = i;
}
}
assert(query_idx >= 0);
assert(k_idx >= 0);

// make sure the query vector matches the vector column (type dimensions etc.)
// TODO not argv[0], source idx from idxStr
rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
rc = vector_from_value(argv[query_idx], &queryVector, &dimensions, &elementType,
&queryVectorCleanup, &pzError);

if (rc != SQLITE_OK) {
Expand Down Expand Up @@ -5485,7 +5507,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
}

// TODO not argv[1], source idx from idxStr
i64 k = sqlite3_value_int64(argv[1]);
i64 k = sqlite3_value_int64(argv[k_idx]);
if (k < 0) {
vtab_set_error(
&p->base, "k value in knn queries must be greater than or equal to 0.");
Expand Down Expand Up @@ -5515,7 +5537,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
// NULL if none were provided, which means a "full" scan.
#if COMPILER_SUPPORTS_VTAB_IN
// TODO fix
if (false /*argc > 2*/) {
if (rowid_in_idx >= 0) {
sqlite3_value *item;
int rc;
arrayRowidsIn = sqlite3_malloc(sizeof(*arrayRowidsIn));
Expand All @@ -5529,8 +5551,8 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
if (rc != SQLITE_OK) {
goto cleanup;
}
for (rc = sqlite3_vtab_in_first(argv[2], &item); rc == SQLITE_OK && item;
rc = sqlite3_vtab_in_next(argv[2], &item)) {
for (rc = sqlite3_vtab_in_first(argv[rowid_in_idx], &item); rc == SQLITE_OK && item;
rc = sqlite3_vtab_in_next(argv[rowid_in_idx], &item)) {
i64 rowid;
if (p->pkIsText) {
rc = vec0_rowid_from_id(p, item, &rowid);
Expand Down
1 change: 0 additions & 1 deletion tests/test-loadable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,6 @@ def test_vec0_point():
assert execute_all(db, "select * from t2 where id = 'xxx'") == []


# @pytest.mark.skip(reason="TODO failing locally for some reason")
def test_vec0_text_pk():
db = connect(EXT_PATH)
db.execute(
Expand Down

0 comments on commit 8369dfe

Please sign in to comment.