Commit

Merge pull request #932 from gchq/fix/kt-probability-fix
feat: fixed swap probability computation for kernel thinning
gw265981 authored Jan 20, 2025
2 parents b1c83a4 + 6cbfc4e commit 8308711
Showing 2 changed files with 66 additions and 52 deletions.
2 changes: 1 addition & 1 deletion coreax/solvers/coresubset.py
@@ -1101,7 +1101,7 @@ def probabilistic_swap(

  prob = jax.random.uniform(key1)
  return lax.cond(
-     prob < 1 / 2 * (1 - alpha / a),
+     prob > 1 / 2 * (1 - alpha / a),
      lambda _: (2 * i, 2 * i + 1),  # first case: val1 = x1, val2 = x2
      lambda _: (2 * i + 1, 2 * i),  # second case: val1 = x2, val2 = x1
      None,
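
Note on the corrected predicate: the following is a minimal pure-Python sketch of the decision made by the `lax.cond` call above, not the library's implementation. The function name `pair_indices` is hypothetical, and reading the first returned index as the S1 element is an assumption based on the worked example in the updated test docstring. It illustrates that, for a fixed uniform draw, the new `prob > threshold` comparison keeps the pair order exactly when the computed swap probability falls below the draw.

```python
from __future__ import annotations


def pair_indices(prob: float, alpha: float, a: float, i: int) -> tuple[int, int]:
    """Return the (first, second) data indices chosen for the i-th pair.

    Hypothetical helper mirroring the corrected predicate in the diff;
    `prob` is the uniform draw, `alpha` and `a` are the per-pair quantities.
    """
    threshold = 0.5 * (1 - alpha / a)
    if prob > threshold:
        return 2 * i, 2 * i + 1  # first case: val1 = x1, val2 = x2
    return 2 * i + 1, 2 * i  # second case: val1 = x2, val2 = x1


# Third pair from the updated docstring: threshold is about 0.431, so a draw
# of 0.5 keeps the original order.
kept = pair_indices(0.5, 0.12977957725524902, 0.9454151391983032, 2)

# Second pair: threshold is about 0.712, so the same draw swaps the pair.
swapped = pair_indices(0.5, -0.014906525611877441, 0.035102729200688874, 1)
```

With the previous `prob < threshold` comparison the two outcomes above would be reversed, which is the behaviour this commit changes.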
116 changes: 65 additions & 51 deletions tests/unit/test_solvers.py
@@ -2179,21 +2179,21 @@ def test_kt_half_analytic(self) -> None:
  r"""
  Test the halving step of kernel thinning on analytical example.
- We aim to split [1, 2, 3, 4, 5, 6, 7, 8] into two coresets, S1 and S2, each
- containing 4 elements, enforcing two unique coresets.
+ We aim to split [0.7,0.55,0.6,0.65,0.9,0.10,0.11,0.12] into two coresets, S1
+ and S2, each containing 4 elements, enforcing two unique coresets.
  First, let S be the full dataset, with S1 and S2 as subsets. S1 will contain
  half the elements, and S2 will contain the other half. Let :math:`k` represent
  the square root kernel. We will use variables labelled :math:`a`, :math:`b`,
  :math:`\\alpha`, :math:`\\sigma`, :math:`\\delta`, and probability, which
  will be updated iteratively to form the coresets.
- We process pairs :math:`(x, y)` sequentially: :math:`(1, 2)`, :math:`(3, 4)`,
- :math:`(5, 6)`, and :math:`(7, 8)`. For each pair, we compute a probability
+ We process pairs :math:`(x, y)` sequentially: :math:`(0.7, 0.55)`, then
+ :math:`(0.6, 0.65)`, and so on. For each pair, we compute a swap probability
  that determines whether :math:`x` goes to S1 and :math:`y` to S2, or vice
  versa. In either case, both :math:`x` and :math:`y` are added to S.
- If this probability is greater than 0.5, we add the x and y to S1 and S2
+ If swap probability is less than 0.5, we add the x and y to S1 and S2
  respectively, otherwise we swap x and y and then add x to S1 and y to S2.
  The process is as follows:
@@ -2245,86 +2245,96 @@ def test_kt_half_analytic(self) -> None:
  Calculations for each pair:
- Pair (1, 2):
+ **Pair (0.7, 0.55):**
  - Inputs: S=[], S1=[], S2=[], sigma=0, delta=1/8.
  - Compute b:
- b(1,2) = sqrt(k(1,1) + k(2,2) - 2*k(1,2)) = 1.1243847608566284.
+ - b(0.7, 0.55) = 0.2109442800283432.
  - Compute alpha: alpha = 0 (as S and S1 are empty).
  - Compute a:
- a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 1.264241099357605.
+ - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.04449748992919922.
  - Update sigma:
- new_sigma^2 = sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2).
- new_sigma = sqrt(new_sigma^2) = 1.1243847608566284.
+ - new_sigma = 0.2109442800283432.
  - Compute probability:
- p = 0.5 * (1 - alpha / a) = 0.5.
+ - p = 0.5 * (1 - alpha / a) = 0.5.
  - Assign:
- Since p <= 0.5, assign x=1 to S2, y=2 to S1, and add both to S.
- S1 = [2], S2 = [1], S = [1, 2].
+ - Since p >= 0.5, assign x=0.7 to S2, y=0.55 to S1, and add both to S.
+ - S1 = [0.55], S2 = [0.7], S = [0.7, 0.55].
+ ---
- Pair (3, 4):
- - Inputs: S=[1, 2], S1=[2], S2=[1], sigma=1.1243847608566284.
+ **Pair (0.6, 0.65):**
+ - Inputs: S=[0.7, 0.55], S1=[0.55], S2=[0.7], sigma=0.2109442800283432.
  - Compute b:
- b(3,4) = sqrt(k(3,3) + k(4,4) - 2*k(3,4)) = 1.1243847608566284.
+ - b(0.6, 0.65) = 0.07066679745912552.
  - Compute alpha:
- alpha = sum(k(s, 3) - k(s, 4) for s in S) - 2 * sum(k(s, 3) - k(s, 4) for s in S1).
- alpha = -0.3313715159893036.
+ - alpha = -0.014906525611877441.
  - Compute a:
- a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 2.9770602825192523.
+ - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.035102729200688874.
  - Update sigma:
- new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2)).
- new_sigma = 1.297198507467962.
+ - new_sigma = 0.2109442800283432.
  - Compute probability:
- p = 0.5 * (1 - alpha / a) = 0.5556541681289673.
+ - p = 0.7123271822929382.
  - Assign:
- Since p > 0.5, assign x=3 to S1 and y=4 to S2, and add both to S.
- S1 = [2, 3], S2 = [1, 4], S = [1, 2, 3, 4].
+ - Since p > 0.5, assign x=0.6 to S2 and y=0.65 to S1, and add both to S.
+ - S1 = [0.55, 0.65], S2 = [0.7, 0.6], S = [0.7, 0.55, 0.6, 0.65].
+ ---
- Pair (5, 6):
- - Inputs: S=[1, 2, 3, 4], S1=[2, 3], S2=[1, 4], sigma=1.297198507467962.
+ **Pair (0.9, 0.1):**
+ - Inputs: S=[0.7, 0.55, 0.6, 0.65], S1=[0.55, 0.65], S2=[0.7, 0.6],
+ sigma=0.2109442800283432.
  - Compute b:
- b(5,6) = sqrt(k(5,5) + k(6,6) - 2*k(5,6)) = 1.1243847608566284.
+ - b(0.9, 0.1) = 0.9723246097564697.
  - Compute alpha:
- alpha = sum(k(s, 5) - k(s, 6) for s in S) - 2 * sum(k(s, 5) - k(s, 6) for s in S1).
- alpha = 0.33124834299087524.
+ - alpha = 0.12977957725524902.
  - Compute a:
- a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 3.434623326772776.
+ - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.9454151391983032.
  - Update sigma:
- new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2)).
- new_sigma = 1.3914653590235087.
+ - new_sigma = 0.9723246097564697.
  - Compute probability:
- p = 0.5 * (1 - alpha / a) = 0.4517780542373657.
+ - p = 0.43136370182037354.
  - Assign:
- Since p <= 0.5, assign x=5 to S2 and y=6 to S1, and add both to S.
- S1 = [2, 3, 6], S2 = [1, 4, 5], S = [1, 2, 3, 4, 5, 6].
+ - Since p < 0.5, assign x=0.9 to S1 and y=0.1 to S2, and add both to S.
+ - S1 = [0.55, 0.65, 0.9], S2 = [0.7, 0.6, 0.1],
+ S = [0.7, 0.55, 0.6, 0.65, 0.9, 0.1].
+ ---
+ **Pair (0.11, 0.12):**
- Pair (7, 8):
- - Inputs: S=[1, 2, 3, 4, 5, 6], S1=[2, 3, 6], S2=[1, 4, 5], sigma=1.3914653590235087.
+ - Inputs: S=[0.7, 0.55, 0.6, 0.65, 0.9, 0.1], S1=[0.55, 0.65, 0.9],
+ S2=[0.7, 0.6, 0.1], sigma=0.9723246097564697.
  - Compute b:
- b(7,8) = sqrt(k(7,7) + k(8,8) - 2*k(7,8)) = 1.1243847608566284.
+ - b(0.11, 0.12) = 0.014143308624625206.
  - Compute alpha:
- alpha = sum(k(s, 7) - k(s, 8) for s in S) - 2 * sum(k(s, 7) - k(s, 8) for s in S1).
- alpha = -0.33124834299087524.
+ - alpha = 0.008038222789764404.
  - Compute a:
- a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 3.6842159106604075.
+ - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.03238321865838572.
  - Update sigma:
- new_sigma = sqrt(sigma^2 + max(0, b^2 * (1 + (b^2 - 2*a) * sigma^2) / a^2)).
- new_sigma = 1.4490018035043584.
+ - new_sigma = 0.9723246097564697.
  - Compute probability:
- p = 0.5 * (1 - alpha / a) = 0.5449550747871399.
+ - p = 0.3758890628814697.
  - Assign:
- Since p > 0.5, assign x=7 to S1 and y=8 to S2, and add both to S.
- S1 = [2, 3, 6, 7], S2 = [1, 4, 5, 8], S = [1, 2, 3, 4, 5, 6, 7, 8].
+ - Since p < 0.5, assign x=0.11 to S1 and y=0.12 to S2, and add both to S.
+ - S1 = [0.55, 0.65, 0.9, 0.11], S2 = [0.7, 0.6, 0.1, 0.12],
+ S = [0.7, 0.55, 0.6, 0.65, 0.9, 0.1, 0.11, 0.12].
- Final result:
- S1 = [2, 3, 6, 7], S2 = [1, 4, 5, 8].
+ ---
+ **Final result:**
+ S1 = [0.55, 0.65, 0.9, 0.11], S2 = [0.7, 0.6, 0.1, 0.12].
  """ # noqa: E501
  # pylint: enable=line-too-long
  length_scale = 1.0 / jnp.sqrt(2)
  kernel = SquaredExponentialKernel()
  sqrt_kernel = SquaredExponentialKernel(length_scale=length_scale)
  delta = 1 / 8
  random_key = jax.random.PRNGKey(seed=0)
- data = Data(jnp.array([1, 2, 3, 4, 5, 6, 7, 8]))
+ data = Data(jnp.array([0.7, 0.55, 0.6, 0.65, 0.9, 0.10, 0.11, 0.12]))
  thinning_solver = KernelThinning(
      coreset_size=2,
      kernel=kernel,
@@ -2342,5 +2352,9 @@ def deterministic_uniform(_key, _shape=None):
      jnp.asarray(s.coreset.data) for s in thinning_solver.kt_half(data)
  ]

- np.testing.assert_array_equal(coresets[0], jnp.array([[2], [3], [6], [7]]))
- np.testing.assert_array_equal(coresets[1], jnp.array([[1], [4], [5], [8]]))
+ np.testing.assert_array_equal(
+     coresets[0], jnp.array([[0.55], [0.65], [0.9], [0.11]])
+ )
+ np.testing.assert_array_equal(
+     coresets[1], jnp.array([[0.7], [0.6], [0.1], [0.12]])
+ )
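
Note on the worked example: the per-pair quantities in the updated docstring can be replayed without JAX. The sketch below is an independent pure-Python re-derivation, not the `KernelThinning` implementation. It assumes the square-root kernel with length scale 1/sqrt(2) reduces to exp(-(x - y)^2), that the patched uniform sampler in the test always returns 0.5 (its body is not shown in this diff), and that the sigma update is parenthesised as b^2 * (1 + (b^2 - 2a) * sigma^2 / a^2). The names `sqrt_kernel`, `kernel_diff` and `replay_halving` are hypothetical. Because the docstring values come from single-precision arithmetic, this double-precision replay agrees with them only to a few decimal places.

```python
import math

DELTA = 1 / 8  # same delta as in the test
UNIFORM_DRAW = 0.5  # assumption: value returned by the test's patched sampler


def sqrt_kernel(x: float, y: float) -> float:
    """Squared exponential kernel with length scale 1/sqrt(2): exp(-(x - y)^2)."""
    return math.exp(-((x - y) ** 2))


def kernel_diff(point: float, x: float, y: float) -> float:
    """k(point, x) - k(point, y), the summand used when computing alpha."""
    return sqrt_kernel(point, x) - sqrt_kernel(point, y)


def replay_halving(points):
    """Replay the pairwise halving steps; return (S1, S2, swap probabilities)."""
    s, s1, s2, probs = [], [], [], []
    sigma = 0.0
    for x, y in zip(points[0::2], points[1::2]):
        b = math.sqrt(sqrt_kernel(x, x) + sqrt_kernel(y, y) - 2 * sqrt_kernel(x, y))
        alpha = sum(kernel_diff(pt, x, y) for pt in s) - 2 * sum(
            kernel_diff(pt, x, y) for pt in s1
        )
        a = max(b * sigma * math.sqrt(2 * math.log(2 / DELTA)), b**2)
        # Sigma grows only when the bracketed term is positive.
        growth = b**2 * (1 + (b**2 - 2 * a) * sigma**2 / a**2)
        sigma = math.sqrt(sigma**2 + max(0.0, growth))
        p = 0.5 * (1 - alpha / a)
        if UNIFORM_DRAW > p:  # corrected predicate: keep the order
            s1.append(x)
            s2.append(y)
        else:  # swap: y goes to S1, x goes to S2
            s1.append(y)
            s2.append(x)
        s.extend([x, y])
        probs.append(p)
    return s1, s2, probs


s1, s2, probs = replay_halving([0.7, 0.55, 0.6, 0.65, 0.9, 0.10, 0.11, 0.12])
print(s1, s2)
```

Under these assumptions the replay splits the data into S1 = [0.55, 0.65, 0.9, 0.11] and S2 = [0.7, 0.6, 0.1, 0.12], the same partition the updated assertions expect.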
