diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index c5fa3454..a8838a51 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -1101,7 +1101,7 @@ def probabilistic_swap( prob = jax.random.uniform(key1) return lax.cond( - prob < 1 / 2 * (1 - alpha / a), + prob > 1 / 2 * (1 - alpha / a), lambda _: (2 * i, 2 * i + 1), # first case: val1 = x1, val2 = x2 lambda _: (2 * i + 1, 2 * i), # second case: val1 = x2, val2 = x1 None, diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index 5bcff47b..ed33906f 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -2179,8 +2179,8 @@ def test_kt_half_analytic(self) -> None: r""" Test the halving step of kernel thinning on analytical example. - We aim to split [1, 2, 3, 4, 5, 6, 7, 8] into two coresets, S1 and S2, each - containing 4 elements, enforcing two unique coresets. + We aim to split [0.7,0.55,0.6,0.65,0.9,0.10,0.11,0.12] into two coresets, S1 + and S2, each containing 4 elements, enforcing two unique coresets. First, let S be the full dataset, with S1 and S2 as subsets. S1 will contain half the elements, and S2 will contain the other half. Let :math:`k` represent @@ -2188,12 +2188,12 @@ def test_kt_half_analytic(self) -> None: :math:`\\alpha`, :math:`\\sigma`, :math:`\\delta`, and probability, which will be updated iteratively to form the coresets. - We process pairs :math:`(x, y)` sequentially: :math:`(1, 2)`, :math:`(3, 4)`, - :math:`(5, 6)`, and :math:`(7, 8)`. For each pair, we compute a probability + We process pairs :math:`(x, y)` sequentially: :math:`(0.7, 0.55)`, then + :math:`(0.6, 0.65)`, and so on. For each pair, we compute a swap probability that determines whether :math:`x` goes to S1 and :math:`y` to S2, or vice versa. In either case, both :math:`x` and :math:`y` are added to S. - If this probability is greater than 0.5, we add the x and y to S1 and S2 + If swap probability is less than 0.5, we add the x and y to S1 and S2 respectively, otherwise we swap x and y and then add x to S1 and y to S2. The process is as follows: @@ -2245,78 +2245,88 @@ def test_kt_half_analytic(self) -> None: Calculations for each pair: - Pair (1, 2): + **Pair (0.7, 0.55):** + - Inputs: S=[], S1=[], S2=[], sigma=0, delta=1/8. - Compute b: - b(1,2) = sqrt(k(1,1) + k(2,2) - 2*k(1,2)) = 1.1243847608566284. + - b(0.7, 0.55) = 0.2109442800283432. - Compute alpha: alpha = 0 (as S and S1 are empty). - Compute a: - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 1.264241099357605. + - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.04449748992919922. - Update sigma: - new_sigma^2 = sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2). - new_sigma = sqrt(new_sigma^2) = 1.1243847608566284. + - new_sigma = 0.2109442800283432. - Compute probability: - p = 0.5 * (1 - alpha / a) = 0.5. + - p = 0.5 * (1 - alpha / a) = 0.5. - Assign: - Since p <= 0.5, assign x=1 to S2, y=2 to S1, and add both to S. - S1 = [2], S2 = [1], S = [1, 2]. + - Since p >= 0.5, assign x=0.7 to S2, y=0.55 to S1, and add both to S. + - S1 = [0.55], S2 = [0.7], S = [0.7, 0.55]. + + --- - Pair (3, 4): - - Inputs: S=[1, 2], S1=[2], S2=[1], sigma=1.1243847608566284. + **Pair (0.6, 0.65):** + + - Inputs: S=[0.7, 0.55], S1=[0.55], S2=[0.7], sigma=0.2109442800283432. - Compute b: - b(3,4) = sqrt(k(3,3) + k(4,4) - 2*k(3,4)) = 1.1243847608566284. + - b(0.6, 0.65) = 0.07066679745912552. - Compute alpha: - alpha = sum(k(s, 3) - k(s, 4) for s in S) - 2 * sum(k(s, 3) - k(s, 4) for s in S1). - alpha = -0.3313715159893036. + - alpha = -0.014906525611877441. - Compute a: - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 2.9770602825192523. + - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.035102729200688874. - Update sigma: - new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2)). - new_sigma = 1.297198507467962. + - new_sigma = 0.2109442800283432. - Compute probability: - p = 0.5 * (1 - alpha / a) = 0.5556541681289673. + - p = 0.7123271822929382. - Assign: - Since p > 0.5, assign x=3 to S1 and y=4 to S2, and add both to S. - S1 = [2, 3], S2 = [1, 4], S = [1, 2, 3, 4]. + - Since p > 0.5, assign x=0.6 to S2 and y=0.65 to S1, and add both to S. + - S1 = [0.55, 0.65], S2 = [0.7, 0.6], S = [0.7, 0.55, 0.6, 0.65]. + + --- - Pair (5, 6): - - Inputs: S=[1, 2, 3, 4], S1=[2, 3], S2=[1, 4], sigma=1.297198507467962. + **Pair (0.9, 0.1):** + + - Inputs: S=[0.7, 0.55, 0.6, 0.65], S1=[0.55, 0.65], S2=[0.7, 0.6], + sigma=0.2109442800283432. - Compute b: - b(5,6) = sqrt(k(5,5) + k(6,6) - 2*k(5,6)) = 1.1243847608566284. + - b(0.9, 0.1) = 0.9723246097564697. - Compute alpha: - alpha = sum(k(s, 5) - k(s, 6) for s in S) - 2 * sum(k(s, 5) - k(s, 6) for s in S1). - alpha = 0.33124834299087524. + - alpha = 0.12977957725524902. - Compute a: - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 3.434623326772776. + - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.9454151391983032. - Update sigma: - new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2)). - new_sigma = 1.3914653590235087. + - new_sigma = 0.9723246097564697. - Compute probability: - p = 0.5 * (1 - alpha / a) = 0.4517780542373657. + - p = 0.43136370182037354. - Assign: - Since p <= 0.5, assign x=5 to S2 and y=6 to S1, and add both to S. - S1 = [2, 3, 6], S2 = [1, 4, 5], S = [1, 2, 3, 4, 5, 6]. + - Since p < 0.5, assign x=0.9 to S1 and y=0.1 to S2, and add both to S. + - S1 = [0.55, 0.65, 0.9], S2 = [0.7, 0.6, 0.1], + S = [0.7, 0.55, 0.6, 0.65, 0.9, 0.1]. + + --- + + **Pair (0.11, 0.12):** - Pair (7, 8): - - Inputs: S=[1, 2, 3, 4, 5, 6], S1=[2, 3, 6], S2=[1, 4, 5], sigma=1.3914653590235087. + - Inputs: S=[0.7, 0.55, 0.6, 0.65, 0.9, 0.1], S1=[0.55, 0.65, 0.9], + S2=[0.7, 0.6, 0.1], sigma=0.9723246097564697. - Compute b: - b(7,8) = sqrt(k(7,7) + k(8,8) - 2*k(7,8)) = 1.1243847608566284. + - b(0.11, 0.12) = 0.014143308624625206. - Compute alpha: - alpha = sum(k(s, 7) - k(s, 8) for s in S) - 2 * sum(k(s, 7) - k(s, 8) for s in S1). - alpha = -0.33124834299087524. + - alpha = 0.008038222789764404. - Compute a: - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 3.6842159106604075. + - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.03238321865838572. - Update sigma: - new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2)). - new_sigma = 1.4490018035043584. + - new_sigma = 0.9723246097564697. - Compute probability: - p = 0.5 * (1 - alpha / a) = 0.5449550747871399. + - p = 0.3758890628814697. - Assign: - Since p > 0.5, assign x=7 to S1 and y=8 to S2, and add both to S. - S1 = [2, 3, 6, 7], S2 = [1, 4, 5, 8], S = [1, 2, 3, 4, 5, 6, 7, 8]. + - Since p < 0.5, assign x=0.11 to S1 and y=0.12 to S2, and add both to S. + - S1 = [0.55, 0.65, 0.9, 0.11], S2 = [0.7, 0.6, 0.1, 0.12], + S = [0.7, 0.55, 0.6, 0.65, 0.9, 0.1, 0.11, 0.12]. - Final result: - S1 = [2, 3, 6, 7], S2 = [1, 4, 5, 8]. + --- + + **Final result:** + + S1 = [0.55, 0.65, 0.9, 0.11], S2 = [0.7, 0.6, 0.1, 0.12]. """ # noqa: E501 # pylint: enable=line-too-long length_scale = 1.0 / jnp.sqrt(2) @@ -2324,7 +2334,7 @@ def test_kt_half_analytic(self) -> None: sqrt_kernel = SquaredExponentialKernel(length_scale=length_scale) delta = 1 / 8 random_key = jax.random.PRNGKey(seed=0) - data = Data(jnp.array([1, 2, 3, 4, 5, 6, 7, 8])) + data = Data(jnp.array([0.7, 0.55, 0.6, 0.65, 0.9, 0.10, 0.11, 0.12])) thinning_solver = KernelThinning( coreset_size=2, kernel=kernel, @@ -2342,5 +2352,9 @@ def deterministic_uniform(_key, _shape=None): jnp.asarray(s.coreset.data) for s in thinning_solver.kt_half(data) ] - np.testing.assert_array_equal(coresets[0], jnp.array([[2], [3], [6], [7]])) - np.testing.assert_array_equal(coresets[1], jnp.array([[1], [4], [5], [8]])) + np.testing.assert_array_equal( + coresets[0], jnp.array([[0.55], [0.65], [0.9], [0.11]]) + ) + np.testing.assert_array_equal( + coresets[1], jnp.array([[0.7], [0.6], [0.1], [0.12]]) + )