From 44aef7a50fdf131fe0846d496e6d23ca0ebdce94 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Fri, 28 Jun 2024 19:21:50 -0700 Subject: [PATCH] memset 0 all applicable mallocs, fix windows? --- sqlite-vec.c | 44 +++++++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/sqlite-vec.c b/sqlite-vec.c index 18bf3e9..3e8edb0 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -431,10 +431,13 @@ struct Array { * SQLITE_NOMEM */ int array_init(struct Array *array, size_t element_size, size_t init_capacity) { - void *z = sqlite3_malloc(element_size * init_capacity); + int sz = element_size * init_capacity; + void *z = sqlite3_malloc(sz); if (!z) { return SQLITE_NOMEM; } + memset(z, 0, sz); + array->element_size = element_size; array->length = 0; array->capacity = init_capacity; @@ -870,6 +873,7 @@ static void vec_npy_file(sqlite3_context *context, int argc, sqlite3_result_error_nomem(context); return; } + memset(f, 0, sizeof(*f)); f->path = path; f->pathLength = pathLength; @@ -1081,9 +1085,11 @@ static void vec_quantize_i8(sqlite3_context *context, int argc, fvec_cleanup cleanup; char *err; int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &cleanup, &err); - assert(rc == SQLITE_OK); - i8 *out = sqlite3_malloc(dimensions * sizeof(i8)); - assert(out); + todo_assert(rc == SQLITE_OK); + int sz = dimensions * sizeof(i8); + i8 *out = sqlite3_malloc(sz); + todo_assert(out); + memset(out, 0, sz); if (argc == 2) { if ((sqlite3_value_type(argv[1]) != SQLITE_TEXT) || @@ -1133,12 +1139,14 @@ static void vec_quantize_binary(sqlite3_context *context, int argc, } if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { - u8 *out = sqlite3_malloc(dimensions / CHAR_BIT); + int sz = dimensions / CHAR_BIT; + u8 *out = sqlite3_malloc(sz); if (!out) { cleanup(vector); sqlite3_result_error_code(context, SQLITE_NOMEM); return; } + memset(out, 0, sz); for (size_t i = 0; i < dimensions; i++) { int res = ((f32 *)vector)[i] > 0.0; out[i / 8] |= (res << (i % 8)); @@ -1146,12 +1154,14 @@ static void vec_quantize_binary(sqlite3_context *context, int argc, sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); } else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) { - u8 *out = sqlite3_malloc(dimensions / CHAR_BIT); + int sz = dimensions / CHAR_BIT; + u8 *out = sqlite3_malloc(sz); if (!out) { cleanup(vector); sqlite3_result_error_code(context, SQLITE_NOMEM); return; } + memset(out, 0, sz); for (size_t i = 0; i < dimensions; i++) { int res = ((i8 *)vector)[i] > 0; out[i / 8] |= (res << (i % 8)); @@ -1193,6 +1203,7 @@ static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) { sqlite3_result_error_nomem(context); goto finish; } + memset(out, 0, outSize); for (size_t i = 0; i < dimensions; i++) { out[i] = ((f32 *)a)[i] + ((f32 *)b)[i]; } @@ -1207,6 +1218,7 @@ static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) { sqlite3_result_error_nomem(context); goto finish; } + memset(out, 0, outSize); for (size_t i = 0; i < dimensions; i++) { out[i] = ((i8 *)a)[i] + ((i8 *)b)[i]; } @@ -1249,6 +1261,7 @@ static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) { sqlite3_result_error_nomem(context); goto finish; } + memset(out, 0, outSize); for (size_t i = 0; i < dimensions; i++) { out[i] = ((f32 *)a)[i] - ((f32 *)b)[i]; } @@ -1263,6 +1276,7 @@ static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) { sqlite3_result_error_nomem(context); goto finish; } + memset(out, 0, outSize); for (size_t i = 0; i < dimensions; i++) { out[i] = ((i8 *)a)[i] - ((i8 *)b)[i]; } @@ -1328,11 +1342,13 @@ static void vec_slice(sqlite3_context *context, int argc, switch (elementType) { case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - f32 *out = sqlite3_malloc(n * sizeof(f32)); + int outSize = n * sizeof(f32); + f32 *out = sqlite3_malloc(outSize); if (!out) { sqlite3_result_error_nomem(context); return; } + memset(out, 0, outSize); for (size_t i = 0; i < n; i++) { out[i] = ((f32 *)vector)[start + i]; } @@ -1341,11 +1357,13 @@ static void vec_slice(sqlite3_context *context, int argc, goto done; } case SQLITE_VEC_ELEMENT_TYPE_INT8: { - i8 *out = sqlite3_malloc(n * sizeof(i8)); + int outSize = n * sizeof(i8); + i8 *out = sqlite3_malloc(outSize); if (!out) { sqlite3_result_error_nomem(context); return; } + memset(out, 0, outSize); for (size_t i = 0; i < n; i++) { out[i] = ((i8 *)vector)[start + i]; } @@ -1362,12 +1380,13 @@ static void vec_slice(sqlite3_context *context, int argc, sqlite3_result_error(context, "end index must be divisible by 8.", -1); goto done; } - - u8 *out = sqlite3_malloc(n / CHAR_BIT); + int outSize = n / CHAR_BIT; + u8 *out = sqlite3_malloc(outSize); if (!out) { sqlite3_result_error_nomem(context); return; } + memset(out, 0, outSize); for (size_t i = 0; i < n / CHAR_BIT; i++) { out[i] = ((u8 *)vector)[(start / CHAR_BIT) + i]; } @@ -1454,12 +1473,14 @@ static void vec_normalize(sqlite3_context *context, int argc, return; } - f32 *out = sqlite3_malloc(dimensions * sizeof(f32)); + int outSize = dimensions * sizeof(f32); + f32 *out = sqlite3_malloc(outSize); if (!out) { cleanup(vector); sqlite3_result_error_code(context, SQLITE_NOMEM); return; } + memset(out, 0, outSize); f32 *v = (f32 *)vector; @@ -5618,6 +5639,7 @@ static void vec_static_blob_from_raw(sqlite3_context *context, int argc, sqlite3_result_error_nomem(context); return; } + memset(p, 0, sizeof(*p)); p->p = sqlite3_value_int64(argv[0]); p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; p->dimensions = sqlite3_value_int64(argv[2]);