Skip to content

Commit

Permalink
memset 0 all applicable mallocs, fix windows?
Browse files Browse the repository at this point in the history
  • Loading branch information
asg017 committed Jun 29, 2024
1 parent da8a19f commit 44aef7a
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions sqlite-vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,13 @@ struct Array {
* SQLITE_NOMEM
*/
int array_init(struct Array *array, size_t element_size, size_t init_capacity) {
void *z = sqlite3_malloc(element_size * init_capacity);
int sz = element_size * init_capacity;
void *z = sqlite3_malloc(sz);
if (!z) {
return SQLITE_NOMEM;
}
memset(z, 0, sz);

array->element_size = element_size;
array->length = 0;
array->capacity = init_capacity;
Expand Down Expand Up @@ -870,6 +873,7 @@ static void vec_npy_file(sqlite3_context *context, int argc,
sqlite3_result_error_nomem(context);
return;
}
memset(f, 0, sizeof(*f));

f->path = path;
f->pathLength = pathLength;
Expand Down Expand Up @@ -1081,9 +1085,11 @@ static void vec_quantize_i8(sqlite3_context *context, int argc,
fvec_cleanup cleanup;
char *err;
int rc = fvec_from_value(argv[0], &srcVector, &dimensions, &cleanup, &err);
assert(rc == SQLITE_OK);
i8 *out = sqlite3_malloc(dimensions * sizeof(i8));
assert(out);
todo_assert(rc == SQLITE_OK);
int sz = dimensions * sizeof(i8);
i8 *out = sqlite3_malloc(sz);
todo_assert(out);
memset(out, 0, sz);

if (argc == 2) {
if ((sqlite3_value_type(argv[1]) != SQLITE_TEXT) ||
Expand Down Expand Up @@ -1133,25 +1139,29 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
}

if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
int sz = dimensions / CHAR_BIT;
u8 *out = sqlite3_malloc(sz);
if (!out) {
cleanup(vector);
sqlite3_result_error_code(context, SQLITE_NOMEM);
return;
}
memset(out, 0, sz);
for (size_t i = 0; i < dimensions; i++) {
int res = ((f32 *)vector)[i] > 0.0;
out[i / 8] |= (res << (i % 8));
}
sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
} else if (elementType == SQLITE_VEC_ELEMENT_TYPE_INT8) {
u8 *out = sqlite3_malloc(dimensions / CHAR_BIT);
int sz = dimensions / CHAR_BIT;
u8 *out = sqlite3_malloc(sz);
if (!out) {
cleanup(vector);
sqlite3_result_error_code(context, SQLITE_NOMEM);
return;
}
memset(out, 0, sz);
for (size_t i = 0; i < dimensions; i++) {
int res = ((i8 *)vector)[i] > 0;
out[i / 8] |= (res << (i % 8));
Expand Down Expand Up @@ -1193,6 +1203,7 @@ static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(out, 0, outSize);
for (size_t i = 0; i < dimensions; i++) {
out[i] = ((f32 *)a)[i] + ((f32 *)b)[i];
}
Expand All @@ -1207,6 +1218,7 @@ static void vec_add(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(out, 0, outSize);
for (size_t i = 0; i < dimensions; i++) {
out[i] = ((i8 *)a)[i] + ((i8 *)b)[i];
}
Expand Down Expand Up @@ -1249,6 +1261,7 @@ static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(out, 0, outSize);
for (size_t i = 0; i < dimensions; i++) {
out[i] = ((f32 *)a)[i] - ((f32 *)b)[i];
}
Expand All @@ -1263,6 +1276,7 @@ static void vec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_error_nomem(context);
goto finish;
}
memset(out, 0, outSize);
for (size_t i = 0; i < dimensions; i++) {
out[i] = ((i8 *)a)[i] - ((i8 *)b)[i];
}
Expand Down Expand Up @@ -1328,11 +1342,13 @@ static void vec_slice(sqlite3_context *context, int argc,

switch (elementType) {
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
f32 *out = sqlite3_malloc(n * sizeof(f32));
int outSize = n * sizeof(f32);
f32 *out = sqlite3_malloc(outSize);
if (!out) {
sqlite3_result_error_nomem(context);
return;
}
memset(out, 0, outSize);
for (size_t i = 0; i < n; i++) {
out[i] = ((f32 *)vector)[start + i];
}
Expand All @@ -1341,11 +1357,13 @@ static void vec_slice(sqlite3_context *context, int argc,
goto done;
}
case SQLITE_VEC_ELEMENT_TYPE_INT8: {
i8 *out = sqlite3_malloc(n * sizeof(i8));
int outSize = n * sizeof(i8);
i8 *out = sqlite3_malloc(outSize);
if (!out) {
sqlite3_result_error_nomem(context);
return;
}
memset(out, 0, outSize);
for (size_t i = 0; i < n; i++) {
out[i] = ((i8 *)vector)[start + i];
}
Expand All @@ -1362,12 +1380,13 @@ static void vec_slice(sqlite3_context *context, int argc,
sqlite3_result_error(context, "end index must be divisible by 8.", -1);
goto done;
}

u8 *out = sqlite3_malloc(n / CHAR_BIT);
int outSize = n / CHAR_BIT;
u8 *out = sqlite3_malloc(outSize);
if (!out) {
sqlite3_result_error_nomem(context);
return;
}
memset(out, 0, outSize);
for (size_t i = 0; i < n / CHAR_BIT; i++) {
out[i] = ((u8 *)vector)[(start / CHAR_BIT) + i];
}
Expand Down Expand Up @@ -1454,12 +1473,14 @@ static void vec_normalize(sqlite3_context *context, int argc,
return;
}

f32 *out = sqlite3_malloc(dimensions * sizeof(f32));
int outSize = dimensions * sizeof(f32);
f32 *out = sqlite3_malloc(outSize);
if (!out) {
cleanup(vector);
sqlite3_result_error_code(context, SQLITE_NOMEM);
return;
}
memset(out, 0, outSize);

f32 *v = (f32 *)vector;

Expand Down Expand Up @@ -5618,6 +5639,7 @@ static void vec_static_blob_from_raw(sqlite3_context *context, int argc,
sqlite3_result_error_nomem(context);
return;
}
memset(p, 0, sizeof(*p));
p->p = sqlite3_value_int64(argv[0]);
p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32;
p->dimensions = sqlite3_value_int64(argv[2]);
Expand Down

0 comments on commit 44aef7a

Please sign in to comment.