Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Support constaints on distance column in KNN queries, for pagination and range queries #166

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions dbg.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.load dist/vec0


create virtual table vec_items using vec0(
vector float[1]
);

insert into vec_items(rowid, vector)
select value, json_array(value) from generate_series(1, 100);


select vec_to_json(vector), distance
from vec_items
where vector match '[1]'
and k = 5;

select vec_to_json(vector), distance
from vec_items
where vector match '[1]'
and k = 5
and distance > 4.0;
139 changes: 137 additions & 2 deletions sqlite-vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -5305,11 +5305,21 @@ static int vec0Close(sqlite3_vtab_cursor *cur) {
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!

// ~~~ KNN QUERIES ~~~ //
VEC0_IDXSTR_KIND_KNN_MATCH = '{',
VEC0_IDXSTR_KIND_KNN_K = '}',
VEC0_IDXSTR_KIND_KNN_ROWID_IN = '[',
// argv[i] is a constraint on a PARTITON KEY column in a KNN query
//
VEC0_IDXSTR_KIND_KNN_PARTITON_CONSTRAINT = ']',

// argv[i] is a constraint on the distance column in a KNN query
VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT = '*',

// ~~~ POINT QUERIES ~~~ //
VEC0_IDXSTR_KIND_POINT_ID = '!',

// ~~~ ??? ~~~ //
VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&',
} vec0_idxstr_kind;

Expand All @@ -5318,11 +5328,22 @@ typedef enum {
typedef enum {
// If any values are updated, please update the ARCHITECTURE.md docs accordingly!

// Equality constraint on a PARTITON KEY column, ex `user_id = 123`
VEC0_PARTITION_OPERATOR_EQ = 'a',

// "Greater than" constraint on a PARTITON KEY column, ex `year > 2024`
VEC0_PARTITION_OPERATOR_GT = 'b',

// "Less than or equal to" constraint on a PARTITON KEY column, ex `year <= 2024`
VEC0_PARTITION_OPERATOR_LE = 'c',

// "Less than" constraint on a PARTITON KEY column, ex `year < 2024`
VEC0_PARTITION_OPERATOR_LT = 'd',

// "Greater than or equal to" constraint on a PARTITON KEY column, ex `year >= 2024`
VEC0_PARTITION_OPERATOR_GE = 'e',

// "Not equal to" constraint on a PARTITON KEY column, ex `year != 2024`
VEC0_PARTITION_OPERATOR_NE = 'f',
} vec0_partition_operator;
typedef enum {
Expand All @@ -5335,6 +5356,15 @@ typedef enum {
VEC0_METADATA_OPERATOR_IN = 'g',
} vec0_metadata_operator;


typedef enum {

VEC0_DISTANCE_CONSTRAINT_GT = 'a',
VEC0_DISTANCE_CONSTRAINT_GE = 'b',
VEC0_DISTANCE_CONSTRAINT_LT = 'c',
VEC0_DISTANCE_CONSTRAINT_LE = 'd',
} vec0_distance_constraint_operator;

static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
vec0_vtab *p = (vec0_vtab *)pVTab;
/**
Expand Down Expand Up @@ -5494,6 +5524,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
}
#endif

// find any PARTITION KEY column constraints
for (int i = 0; i < pIdxInfo->nConstraint; i++) {
if (!pIdxInfo->aConstraint[i].usable)
continue;
Expand Down Expand Up @@ -5548,6 +5579,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {

}

// find any metadata column constraints
for (int i = 0; i < pIdxInfo->nConstraint; i++) {
if (!pIdxInfo->aConstraint[i].usable)
continue;
Expand Down Expand Up @@ -5644,6 +5676,58 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {

}

// find any distance column constraints
for (int i = 0; i < pIdxInfo->nConstraint; i++) {
if (!pIdxInfo->aConstraint[i].usable)
continue;

int iColumn = pIdxInfo->aConstraint[i].iColumn;
int op = pIdxInfo->aConstraint[i].op;
if(op == SQLITE_INDEX_CONSTRAINT_LIMIT || op == SQLITE_INDEX_CONSTRAINT_OFFSET) {
continue;
}
if(vec0_column_distance_idx(p) != iColumn) {
continue;
}

char value = 0;
switch(op) {
case SQLITE_INDEX_CONSTRAINT_GT: {
value = VEC0_DISTANCE_CONSTRAINT_GT;
break;
}
case SQLITE_INDEX_CONSTRAINT_GE: {
value = VEC0_DISTANCE_CONSTRAINT_GE;
break;
}
case SQLITE_INDEX_CONSTRAINT_LT: {
value = VEC0_DISTANCE_CONSTRAINT_LT;
break;
}
case SQLITE_INDEX_CONSTRAINT_LE: {
value = VEC0_DISTANCE_CONSTRAINT_LE;
break;
}
default: {
// IMP TODO
rc = SQLITE_ERROR;
vtab_set_error(
pVTab,
"Illegal WHERE constraint on distance column in a KNN query. "
"Only one of GT, GE, LT, LE constraints are allowed."
);
goto done;
}
}

pIdxInfo->aConstraintUsage[i].argvIndex = argvIndex++;
pIdxInfo->aConstraintUsage[i].omit = 1;
sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT);
sqlite3_str_appendchar(idxStr, 1, value);
sqlite3_str_appendchar(idxStr, 1, '_');
sqlite3_str_appendchar(idxStr, 1, '_');
}



pIdxInfo->idxNum = iMatchVectorTerm;
Expand Down Expand Up @@ -5672,7 +5756,6 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
}
pIdxInfo->needToFreeIdxStr = 1;


rc = SQLITE_OK;

done:
Expand Down Expand Up @@ -6560,12 +6643,15 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
int numValueEntries = (idxStrLength-1) / 4;
assert(numValueEntries == argc);
int hasMetadataFilters = 0;
int hasDistanceConstraints = 0;
for(int i = 0; i < argc; i++) {
int idx = 1 + (i * 4);
char kind = idxStr[idx + 0];
if(kind == VEC0_IDXSTR_KIND_METADATA_CONSTRAINT) {
hasMetadataFilters = 1;
break;
}
else if(kind == VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT) {
hasDistanceConstraints = 1;
}
}

Expand Down Expand Up @@ -6752,6 +6838,55 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
chunk_distances[i] = result;
}

if(hasDistanceConstraints) {
for(int i = 0; i < argc; i++) {
int idx = 1 + (i * 4);
char kind = idxStr[idx + 0];
// TODO casts f64 to f32, is that a problem?
f32 target = (f32) sqlite3_value_double(argv[i]);

if(kind != VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT) {
continue;
}
vec0_distance_constraint_operator op = idxStr[idx + 1];

switch(op) {
case VEC0_DISTANCE_CONSTRAINT_GE: {
for(int i = 0; i < p->chunk_size;i++) {
if(bitmap_get(b, i) && !(chunk_distances[i] >= target)) {
bitmap_set(b, i, 0);
}
}
break;
}
case VEC0_DISTANCE_CONSTRAINT_GT: {
for(int i = 0; i < p->chunk_size;i++) {
if(bitmap_get(b, i) && !(chunk_distances[i] > target)) {
bitmap_set(b, i, 0);
}
}
break;
}
case VEC0_DISTANCE_CONSTRAINT_LE: {
for(int i = 0; i < p->chunk_size;i++) {
if(bitmap_get(b, i) && !(chunk_distances[i] <= target)) {
bitmap_set(b, i, 0);
}
}
break;
}
case VEC0_DISTANCE_CONSTRAINT_LT: {
for(int i = 0; i < p->chunk_size;i++) {
if(bitmap_get(b, i) && !(chunk_distances[i] < target)) {
bitmap_set(b, i, 0);
}
}
break;
}
}
}
}

int used1;
min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs,
min(k, p->chunk_size), bTaken, &used1);
Expand Down