diff --git a/jaxopt/_src/projection.py b/jaxopt/_src/projection.py
index e9dce7bd..04ab9bf3 100644
--- a/jaxopt/_src/projection.py
+++ b/jaxopt/_src/projection.py
@@ -392,13 +392,18 @@ def projection_box_section(x: jnp.ndarray,
 
 def _max_l2(x, marginal_b, gamma):
   scale = gamma * marginal_b
-  p = projection_simplex(x / scale)
+  x_scale = x / scale
+  p = projection_simplex(x_scale)
+  # From Danskin's theorem, we do not need to backpropagate
+  # through projection_simplex.
+  p = jax.lax.stop_gradient(p)
   return jnp.dot(x, p) - 0.5 * scale * jnp.dot(p, p)
 
 
 def _max_ent(x, marginal_b, gamma):
   return gamma * logsumexp(x / gamma) - gamma * jnp.log(marginal_b)
 
 
+
 _max_l2_vmap = jax.vmap(_max_l2, in_axes=(1, 0, None))
 _max_l2_grad_vmap = jax.vmap(jax.grad(_max_l2), in_axes=(1, 0, None))
@@ -771,4 +776,4 @@ def kl_projection_birkhoff(sim_matrix: jnp.ndarray,
   return kl_projection_transport(sim_matrix=sim_matrix,
                                  marginals=(marginals_a, marginals_b),
                                  make_solver=make_solver,
-                                 use_semi_dual=use_semi_dual)
\ No newline at end of file
+                                 use_semi_dual=use_semi_dual)