From 25b85afc8953da51dc51e696a4d129f12b1f00eb Mon Sep 17 00:00:00 2001 From: Daniel Levi-Minzi <51272568+dleviminzi@users.noreply.github.com> Date: Tue, 23 Jul 2024 12:04:15 -0400 Subject: [PATCH] l1 distance (#39) * initial work on l1 * l1 int8 neon implementation * tweak l1 int8 and add test * broken overflow still * some progress on l1 * change to i32 instead of i64 * remove comment * ignore poetry stuff * unrolled l1 int8 and format * remove comments --- .gitignore | 3 + sqlite-vec.c | 207 ++++++++++++++++++++++++++++++++++++++--- tests/test-loadable.py | 72 +++++++++++++- 3 files changed, 267 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index a64dfe2..38f9876 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,6 @@ examples/sotu sqlite-vec.h tmp/ + +poetry.lock +pyproject.toml \ No newline at end of file diff --git a/sqlite-vec.c b/sqlite-vec.c index 9b02872..fdbf255 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -230,6 +230,100 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v, return sqrtf(sum_scalar); } + +static i32 l1_int8_neon(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + i8 *pVect1 = (i8 *)pVect1v; + i8 *pVect2 = (i8 *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + + const int8_t *pEnd1 = pVect1 + qty; + + int32x4_t acc1 = vdupq_n_s32(0); + int32x4_t acc2 = vdupq_n_s32(0); + int32x4_t acc3 = vdupq_n_s32(0); + int32x4_t acc4 = vdupq_n_s32(0); + + while (pVect1 < pEnd1 - 63) { + int8x16_t v1 = vld1q_s8(pVect1); + int8x16_t v2 = vld1q_s8(pVect2); + int8x16_t diff1 = vabdq_s8(v1, v2); + acc1 = vaddq_s32(acc1, vpaddlq_u16(vpaddlq_u8(diff1))); + + v1 = vld1q_s8(pVect1 + 16); + v2 = vld1q_s8(pVect2 + 16); + int8x16_t diff2 = vabdq_s8(v1, v2); + acc2 = vaddq_s32(acc2, vpaddlq_u16(vpaddlq_u8(diff2))); + + v1 = vld1q_s8(pVect1 + 32); + v2 = vld1q_s8(pVect2 + 32); + int8x16_t diff3 = vabdq_s8(v1, v2); + acc3 = vaddq_s32(acc3, vpaddlq_u16(vpaddlq_u8(diff3))); + + v1 = vld1q_s8(pVect1 + 48); + v2 = vld1q_s8(pVect2 + 48); + int8x16_t diff4 = vabdq_s8(v1, v2); + acc4 = vaddq_s32(acc4, vpaddlq_u16(vpaddlq_u8(diff4))); + + pVect1 += 64; + pVect2 += 64; + } + + while (pVect1 < pEnd1 - 15) { + int8x16_t v1 = vld1q_s8(pVect1); + int8x16_t v2 = vld1q_s8(pVect2); + int8x16_t diff = vabdq_s8(v1, v2); + acc1 = vaddq_s32(acc1, vpaddlq_u16(vpaddlq_u8(diff))); + pVect1 += 16; + pVect2 += 16; + } + + int32x4_t acc = vaddq_s32(vaddq_s32(acc1, acc2), vaddq_s32(acc3, acc4)); + + int32_t sum = 0; + while (pVect1 < pEnd1) { + int32_t diff = abs((int32_t)*pVect1 - (int32_t)*pVect2); + sum += diff; + pVect1++; + pVect2++; + } + + return vaddvq_s32(acc) + sum; +} + +static double l1_f32_neon(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + f32 *pVect1 = (f32 *)pVect1v; + f32 *pVect2 = (f32 *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + + const f32 *pEnd1 = pVect1 + qty; + float64x2_t acc = vdupq_n_f64(0); + + while (pVect1 < pEnd1 - 3) { + float32x4_t v1 = vld1q_f32(pVect1); + float32x4_t v2 = vld1q_f32(pVect2); + pVect1 += 4; + pVect2 += 4; + + // f32x4 -> f64x2 pad for overflow + float64x2_t low_diff = vabdq_f64(vcvt_f64_f32(vget_low_f32(v1)), + vcvt_f64_f32(vget_low_f32(v2))); + float64x2_t high_diff = + vabdq_f64(vcvt_high_f64_f32(v1), vcvt_high_f64_f32(v2)); + + acc = vaddq_f64(acc, vaddq_f64(low_diff, high_diff)); + } + + double sum = 0; + while (pVect1 < pEnd1) { + sum += fabs((double)*pVect1 - (double)*pVect2); + pVect1++; + pVect2++; + } + + return vaddvq_f64(acc) + sum; +} #endif static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v, @@ -286,6 +380,54 @@ static f32 distance_l2_sqr_int8(const void *a, const void *b, const void *d) { return l2_sqr_int8(a, b, d); } +static i32 l1_int8(const void *pA, const void *pB, const void *pD) { + i8 *a = (i8 *)pA; + i8 *b = (i8 *)pB; + size_t d = *((size_t *)pD); + + i32 res = 0; + for (size_t i = 0; i < d; i++) { + res += abs(*a - *b); + a++; + b++; + } + + return res; +} + +static i32 distance_l1_int8(const void *a, const void *b, const void *d) { +#ifdef SQLITE_VEC_ENABLE_NEON + if ((*(const size_t *)d) > 15) { + return l1_int8_neon(a, b, d); + } +#endif + return l1_int8(a, b, d); +} + +static double l1_f32(const void *pA, const void *pB, const void *pD) { + f32 *a = (f32 *)pA; + f32 *b = (f32 *)pB; + size_t d = *((size_t *)pD); + + double res = 0; + for (size_t i = 0; i < d; i++) { + res += fabs((double)*a - (double)*b); + a++; + b++; + } + + return res; +} + +static double distance_l1_f32(const void *a, const void *b, const void *d) { +#ifdef SQLITE_VEC_ENABLE_NEON + if ((*(const size_t *)d) > 3) { + return l1_f32_neon(a, b, d); + } +#endif + return l1_f32(a, b, d); +} + static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { f32 *pVect1 = (f32 *)pVect1v; @@ -1040,6 +1182,48 @@ static void vec_distance_l2(sqlite3_context *context, int argc, bCleanup(b); return; } + +static void vec_distance_l1(sqlite3_context *context, int argc, + sqlite3_value **argv) { + assert(argc == 2); + int rc; + void *a, *b; + size_t dimensions; + vector_cleanup aCleanup, bCleanup; + char *error; + enum VectorElementType elementType; + rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions, + &aCleanup, &bCleanup, &error); + if (rc != SQLITE_OK) { + sqlite3_result_error(context, error, -1); + sqlite3_free(error); + return; + } + + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_BIT: { + sqlite3_result_error( + context, "Cannot calculate L1 distance between two bitvectors.", -1); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { + double result = distance_l1_f32(a, b, &dimensions); + sqlite3_result_double(context, result); + goto finish; + } + case SQLITE_VEC_ELEMENT_TYPE_INT8: { + i64 result = distance_l1_int8(a, b, &dimensions); + sqlite3_result_int(context, result); + goto finish; + } + } + +finish: + aCleanup(a); + bCleanup(b); + return; +} + static void vec_distance_hamming(sqlite3_context *context, int argc, sqlite3_value **argv) { assert(argc == 2); @@ -2608,9 +2792,9 @@ static int vec_npy_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { static int vec_npy_eachClose(sqlite3_vtab_cursor *cur) { vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; if (pCur->file) { - #ifndef SQLITE_VEC_OMIT_FS +#ifndef SQLITE_VEC_OMIT_FS fclose(pCur->file); - #endif +#endif pCur->file = NULL; } if (pCur->chunksBuffer) { @@ -2664,9 +2848,9 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)pVtabCursor; if (pCur->file) { - #ifndef SQLITE_VEC_OMIT_FS +#ifndef SQLITE_VEC_OMIT_FS fclose(pCur->file); - #endif +#endif pCur->file = NULL; } if (pCur->chunksBuffer) { @@ -2679,7 +2863,7 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, struct VecNpyFile *f = NULL; - #ifndef SQLITE_VEC_OMIT_FS +#ifndef SQLITE_VEC_OMIT_FS if ((f = sqlite3_value_pointer(argv[0], SQLITE_VEC_NPY_FILE_NAME))) { FILE *file = fopen(f->path, "r"); if (!file) { @@ -2689,15 +2873,15 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, rc = parse_npy_file(pVtabCursor->pVtab, file, pCur); if (rc != SQLITE_OK) { - #ifndef SQLITE_VEC_OMIT_FS +#ifndef SQLITE_VEC_OMIT_FS fclose(file); - #endif +#endif return rc; } } else - #endif - { +#endif + { const unsigned char *input = sqlite3_value_blob(argv[0]); int inputLength = sqlite3_value_bytes(argv[0]); @@ -2744,7 +2928,7 @@ static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) { return SQLITE_OK; } - #ifndef SQLITE_VEC_OMIT_FS +#ifndef SQLITE_VEC_OMIT_FS // else: input is a file pCur->currentChunkIndex++; if (pCur->currentChunkIndex >= pCur->currentChunkSize) { @@ -2757,7 +2941,7 @@ static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) { } pCur->currentChunkIndex = 0; } - #endif +#endif return SQLITE_OK; } @@ -6658,6 +6842,7 @@ __declspec(dllexport) //{"vec_version", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_VERSION }, //{"vec_debug", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_DEBUG_STRING }, {"vec_distance_l2", vec_distance_l2, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, + {"vec_distance_l1", vec_distance_l1, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL }, {"vec_distance_hamming",vec_distance_hamming, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, {"vec_distance_cosine", vec_distance_cosine, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, {"vec_length", vec_length, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE, }, diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 293c667..6624374 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -104,6 +104,7 @@ def spread_args(args): "vec_debug", "vec_distance_cosine", "vec_distance_hamming", + "vec_distance_l1", "vec_distance_l2", "vec_f32", "vec_int8", @@ -309,6 +310,68 @@ def test_vec_distance_hamming(): db.execute("select vec_distance_hamming(vec_int8(X'FF'), vec_int8(X'FF'))") +def test_vec_distance_l1(): + vec_distance_l1 = lambda *args, a="?", b="?": db.execute( + f"select vec_distance_l1({a}, {b})", args + ).fetchone()[0] + + def check(a, b, dtype=np.float32): + if dtype == np.float32: + transform = "?" + elif dtype == np.int8: + transform = "vec_int8(?)" + + a_sql_t = np.array(a, dtype=dtype) + b_sql_t = np.array(b, dtype=dtype) + + x = vec_distance_l1(a_sql_t, b_sql_t, a=transform, b=transform) + # dont use dtype here bc overflow + y = np.sum(np.abs(np.array(a) - np.array(b))) + assert isclose(x, y, abs_tol=1e-6) + + check([1, 2, 3], [-9, -8, -7], dtype=np.int8) + # check overflow + check([127] * 20, [-128] * 20, dtype=np.int8) + check([-128, 127], [127, -128], dtype=np.int8) + check( + [1, 2, 3, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 8, 1], + [1, 20, 38, 23, 29, 4, 10, 9, 3, 1, 20, 38, 23, 29, 4, 10, 9, 3], + dtype=np.int8, + ) + check([0] * 20, [0] * 20, dtype=np.int8) + check( + [5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20], + [5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20, 5, 15, -20], + dtype=np.int8, + ) + check([100] * 20, [-100] * 20, dtype=np.int8) + check([127] * 1000000, [-128] * 1000000, dtype=np.int8) + + check( + [1.2, 0.1, 0.5, 0.9, 1.4, 4.5], + [0.4, -0.4, 0.1, 0.1, 0.5, 0.9], + dtype=np.float32, + ) + check([1.0, 2.0, 3.0], [-1.0, -2.0, -3.0], dtype=np.float32) + check( + [1e10, 2e10, np.finfo(np.float32).max], + [-1e10, -2e10, np.finfo(np.float32).min], + dtype=np.float32, + ) + # overflow in leftover elements + check( + [1e10, 2e10, 1e10, 2e10, np.finfo(np.float32).max], + [-1e10, -2e10, -1e10, -2e10, np.finfo(np.float32).min], + dtype=np.float32, + ) + # overflow in neon elements + check( + [np.finfo(np.float32).max, 1e10, 2e10, 1e10, 2e10], + [np.finfo(np.float32).min, -1e10, -2e10, -1e10, -2e10], + dtype=np.float32, + ) + + def test_vec_distance_l2(): vec_distance_l2 = lambda *args, a="?", b="?": db.execute( f"select vec_distance_l2({a}, {b})", args @@ -319,11 +382,12 @@ def check(a, b, dtype=np.float32): transform = "?" elif dtype == np.int8: transform = "vec_int8(?)" - a = np.array(a, dtype=dtype) - b = np.array(b, dtype=dtype) - x = vec_distance_l2(a, b, a=transform, b=transform) - y = npy_l2(a, b) + a_sql_t = np.array(a, dtype=dtype) + b_sql_t = np.array(b, dtype=dtype) + + x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform) + y = npy_l2(np.array(a), np.array(b)) assert isclose(x, y, abs_tol=1e-6) check([1.2, 0.1], [0.4, -0.4])