From 9730d9b9f2e56e492c518205c3bbf7df200a3b7f Mon Sep 17 00:00:00 2001 From: jue-yuan Date: Thu, 5 Dec 2024 19:43:59 +0000 Subject: [PATCH] [GLE-8861] address comments; --- gds/vector/cosine_distance.gsql | 13 ++++++++++++- gds/vector/distance.gsql | 20 +++++++++++++++++--- gds/vector/norm.gsql | 13 +++---------- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/gds/vector/cosine_distance.gsql b/gds/vector/cosine_distance.gsql index 28664c85..76221f4d 100644 --- a/gds/vector/cosine_distance.gsql +++ b/gds/vector/cosine_distance.gsql @@ -28,6 +28,8 @@ CREATE FUNCTION gds.vector.cosine_distance(list list1, list list Exceptions: list_size_mismatch (90000): Raised when the input lists are not of equal size. + zero_divisor(90001); + Raised either list is all zero to avoid zero-divisor issue. Logic Overview: Validates that both input vectors have the same length. @@ -42,6 +44,7 @@ CREATE FUNCTION gds.vector.cosine_distance(list list1, list list */ EXCEPTION list_size_mismatch (90000); + EXCEPTION zero_divisor(90001); ListAccum @@myList1 = list1; ListAccum @@myList2 = list2; @@ -49,8 +52,16 @@ CREATE FUNCTION gds.vector.cosine_distance(list list1, list list RAISE list_size_mismatch ("Two lists provided for gds.vector.cosine_distance have different sizes."); END; - double innerP = inner_product(@@myList1, @@myList2); + double inner_p = inner_product(@@myList1, @@myList2); double v1_magn = sqrt(inner_product(@@myList1, @@myList1)); double v2_magn = sqrt(inner_product(@@myList2, @@myList2)); + IF (abs(v1_magn) < 0.0000001) THEN + // use a small positive float to avoid numeric comparison error + RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor."); + END; + IF (abs(v1_magn) < 0.0000001) THEN + // use a small positive float to avoid numeric comparison error + RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor."); + END; RETURN (1 - innerP / (v1_magn * v2_magn)); } \ No newline at end of file diff --git a/gds/vector/distance.gsql b/gds/vector/distance.gsql index 06d4df37..b17af378 100644 --- a/gds/vector/distance.gsql +++ b/gds/vector/distance.gsql @@ -32,7 +32,9 @@ CREATE FUNCTION gds.vector.distance(list list1, list list2, stri Exceptions: list_size_mismatch (90000): Raised when the input vectors are not of equal size. - invalid_metric_type (90001): + zero_divisor(90001); + Raised either list is all zero to avoid zero-divisor issue. + invalid_metric_type (90002): Raised when an unsupported distance metric is provided. Logic Overview: @@ -55,7 +57,8 @@ CREATE FUNCTION gds.vector.distance(list list1, list list2, stri */ EXCEPTION list_size_mismatch (90000); - EXCEPTION invalid_metric_type (90001); + EXCEPTION zero_divisor(90001); + EXCEPTION invalid_metric_type (90002); ListAccum @@myList1 = list1; ListAccum @@myList2 = list2; @@ -68,7 +71,18 @@ CREATE FUNCTION gds.vector.distance(list list1, list list2, stri CASE lower(metric) WHEN "cosine" THEN - @@myResult = 1 - inner_product(@@myList1, @@myList2) / (sqrt(inner_product(@@myList1, @@myList1)) * sqrt(inner_product(@@myList2, @@myList2))); + double inner_p = inner_product(@@myList1, @@myList2); + double v1_magn = sqrt(inner_product(@@myList1, @@myList1)); + double v2_magn = sqrt(inner_product(@@myList2, @@myList2)); + IF (abs(v1_magn) < 0.0000001) THEN + // use a small positive float to avoid numeric comparison error + RAISE zero_divisor ("The elements in the first list are all zero. It will introduce a zero divisor."); + END; + IF (abs(v2_magn) < 0.0000001) THEN + // use a small positive float to avoid numeric comparison error + RAISE zero_divisor ("The elements in the second list are all zero. It will introduce a zero divisor."); + END; + @@myResult = 1 - inner_p / (v1_magn * v2_magn); WHEN "l2" THEN FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO @@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i)); diff --git a/gds/vector/norm.gsql b/gds/vector/norm.gsql index 49e60bb2..feee0e1b 100644 --- a/gds/vector/norm.gsql +++ b/gds/vector/norm.gsql @@ -53,23 +53,16 @@ CREATE FUNCTION gds.vector.norm(list list1, string metric) RETURNS(float EXCEPTION invalid_metric_type (90001); ListAccum @@myList1 = list1; - ListAccum @@myList2; - - FOREACH i IN RANGE [0, @@myList1.size() - 1] DO - @@myList2 += 0; - end; SumAccum @@myResult; SumAccum @@sqrSum; CASE lower(metric) WHEN "l2" THEN - FOREACH i IN RANGE [0, @@myList1.size() - 1 ] DO - @@sqrSum += (@@myList1.get(i) - @@myList2.get(i)) * (@@myList1.get(i) - @@myList2.get(i)); - END; - @@myResult = sqrt(@@sqrSum); + @@myResult = sqrt(inner_product(@@myList1, @@myList1)); WHEN "ip" THEN - @@myResult = inner_product(@@myList1, @@myList2); + // the result of inner product between any vector and all-zero vector should always be 0 + @@myResult = 0; ELSE RAISE invalid_metric_type ("Invalid metric algorithm provided, currently supported: l2 and ip."); END