From db28f4bce5240220528eb34533775d4e515a66cf Mon Sep 17 00:00:00 2001
From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com>
Date: Thu, 21 Nov 2024 13:05:50 +0100
Subject: [PATCH] [MRG] Sinkhorn gradient last step (#693)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* change solver
* test
* update test
* Update ot/solvers.py

Co-authored-by: Rémi Flamary

* update doc
* add test for max_iter
* fix bug on gradients
* update RELEASES.md
* update comment
* add detach and comment
* add example
* add test for detach
* fix example
* delete unused importations in example
* move example to backend
* reduce n_trials for example

---------

Co-authored-by: Rémi Flamary
---
 RELEASES.md                                  |  1 +
 examples/backends/plot_Sinkhorn_gradients.py | 85 +++++++++++++++++++
 ot/solvers.py                                | 37 +++++++--
 test/test_solvers.py                         | 87 +++++++++++++++++++-
 4 files changed, 203 insertions(+), 7 deletions(-)
 create mode 100644 examples/backends/plot_Sinkhorn_gradients.py

diff --git a/RELEASES.md b/RELEASES.md
index 1caaa04b4..e74b44507 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -4,6 +4,7 @@
 
 #### New features
 - Implement CG solvers for partial FGW (PR #687)
+- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
diff --git a/examples/backends/plot_Sinkhorn_gradients.py b/examples/backends/plot_Sinkhorn_gradients.py
new file mode 100644
index 000000000..229a952b1
--- /dev/null
+++ b/examples/backends/plot_Sinkhorn_gradients.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+"""
+==================================================================
+Different gradient computations for regularized optimal transport
+==================================================================
+
+This example illustrates the differences in computation time between the
+gradient options of the Sinkhorn solver.
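+The three options compared below are ``autodiff`` (gradients for all outputs:
+plan, value, value_linear), ``envelope`` (gradient for the value only) and
+``last_step`` (gradients for all outputs, propagated through the last Sinkhorn
+iteration only).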
+ +""" + +# Author: Sonia Mazelet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import matplotlib.pylab as pl +import ot +from ot.backend import torch + + +############################################################################## +# Time comparison of the Sinkhorn solver for different gradient options +# ------------- + + +# %% parameters + +n_trials = 10 +times_autodiff = torch.zeros(n_trials) +times_envelope = torch.zeros(n_trials) +times_last_step = torch.zeros(n_trials) + +n_samples_s = 300 +n_samples_t = 300 +n_features = 5 +reg = 0.03 + +# Time required for the Sinkhorn solver and gradient computations, for different gradient options over multiple Gaussian distributions +for i in range(n_trials): + x = torch.rand((n_samples_s, n_features)) + y = torch.rand((n_samples_t, n_features)) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + M = ot.dist(x, y) + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = M.clone().detach().requires_grad_(True) + + # autodiff provides the gradient for all the outputs (plan, value, value_linear) + ot.tic() + res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff") + res_autodiff.value.backward() + times_autodiff[i] = ot.toq() + + a = a.clone().detach().requires_grad_(True) + b = b.clone().detach().requires_grad_(True) + M = M.clone().detach().requires_grad_(True) + + # envelope provides the gradient for value + ot.tic() + res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope") + res_envelope.value.backward() + times_envelope[i] = ot.toq() + + a = a.clone().detach().requires_grad_(True) + b = b.clone().detach().requires_grad_(True) + M = M.clone().detach().requires_grad_(True) + + # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm + ot.tic() + res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step") + res_last_step.value.backward() + times_last_step[i] = ot.toq() + +pl.figure(1, figsize=(5, 3)) +pl.ticklabel_format(axis="y", style="sci", scilimits=(0, 0)) +pl.boxplot( + ([times_autodiff, times_envelope, times_last_step]), + tick_labels=["autodiff", "envelope", "last_step"], + showfliers=False, +) +pl.ylabel("Time (s)") +pl.show() diff --git a/ot/solvers.py b/ot/solvers.py index 508f248d5..80f366354 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -125,11 +125,13 @@ def solve( verbose : bool, optional Print information in the solver, by default False grad : str, optional - Type of gradient computation, either or 'autodiff' or 'envelope' used only for + Type of gradient computation, either or 'autodiff', 'envelope' or 'last_step' used only for Sinkhorn solver. By default 'autodiff' provides gradients wrt all outputs (`plan, value, value_linear`) but with important memory cost. 'envelope' provides gradients only for `value` and and other outputs are - detached. This is useful for memory saving when only the value is needed. + detached. This is useful for memory saving when only the value is needed. 'last_step' provides + gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values. + 'detach' does not compute the gradients for the Sinkhorn solver. Returns ------- @@ -281,7 +283,6 @@ def solve( linear regression. NeurIPS. 
""" - # detect backend nx = get_backend(M, a, b, c) @@ -412,7 +413,11 @@ def solve( potentials = (log["u"], log["v"]) elif reg_type.lower() in ["entropy", "kl"]: - if grad == "envelope": # if envelope then detach the input + if grad in [ + "envelope", + "last_step", + "detach", + ]: # if envelope, last_step or detach then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) @@ -421,6 +426,12 @@ def solve( max_iter = 1000 if tol is None: tol = 1e-9 + if grad == "last_step": + if max_iter == 0: + raise ValueError( + "The maximum number of iterations must be greater than 0 when using grad=last_step." + ) + max_iter = max_iter - 1 plan, log = sinkhorn_log( a, @@ -433,6 +444,22 @@ def solve( verbose=verbose, ) + potentials = (log["log_u"], log["log_v"]) + + # if last_step, compute the last step of the Sinkhorn algorithm with the non-detached inputs + if grad == "last_step": + loga = nx.log(a0) + logb = nx.log(b0) + v = logb - nx.logsumexp(-M0 / reg + potentials[0][:, None], 0) + u = loga - nx.logsumexp(-M0 / reg + potentials[1][None, :], 1) + plan = nx.exp(-M0 / reg + u[:, None] + v[None, :]) + potentials = (u, v) + log["niter"] = max_iter + 1 + log["log_u"] = u + log["log_v"] = v + log["u"] = nx.exp(u) + log["v"] = nx.exp(v) + value_linear = nx.sum(M * plan) if reg_type.lower() == "entropy": @@ -442,8 +469,6 @@ def solve( plan, a[:, None] * b[None, :] ) - potentials = (log["log_u"], log["log_v"]) - if grad == "envelope": # set the gradient at convergence value = nx.set_gradients( value, diff --git a/test/test_solvers.py b/test/test_solvers.py index f6077e005..85852aca6 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -143,6 +143,91 @@ def test_solve(nx): sol0 = ot.solve(M, reg=1, reg_type="cryptic divergence") +@pytest.mark.skipif(not torch, reason="torch no installed") +def test_solve_last_step(): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + M = ot.dist(x, y) + + # Check that last_step and autodiff give the same result and similar gradients + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol0 = ot.solve(M, a, b, reg=10, grad="autodiff") + sol0.value.backward() + + gM0 = M.grad.clone() + ga0 = a.grad.clone() + gb0 = b.grad.clone() + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol = ot.solve(M, a, b, reg=10, grad="last_step") + sol.value.backward() + + gM = M.grad.clone() + ga = a.grad.clone() + gb = b.grad.clone() + + # Note, gradients are invariant to change in constant so we center them + cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) + tolerance = 0.96 + assert cos(gM0.flatten(), gM.flatten()) > tolerance + assert cos(ga0 - ga0.mean(), ga - ga.mean()) > tolerance + assert cos(gb0 - gb0.mean(), gb - gb.mean()) > tolerance + + assert torch.allclose(sol0.plan, sol.plan) + assert torch.allclose(sol0.value, sol.value) + assert torch.allclose(sol0.value_linear, sol.value_linear) + assert torch.allclose(sol0.potentials[0], sol.potentials[0]) + assert torch.allclose(sol0.potentials[1], sol.potentials[1]) + + with pytest.raises(ValueError): + ot.solve(M, a, b, grad="last_step", max_iter=0, reg=10) + + +@pytest.mark.skipif(not torch, reason="torch no installed") +def test_solve_detach(): + n_samples_s = 10 + 
+    n_samples_t = 7
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples_s, n_features)
+    y = rng.randn(n_samples_t, n_features)
+    a = ot.utils.unif(n_samples_s)
+    b = ot.utils.unif(n_samples_t)
+    M = ot.dist(x, y)
+
+    # Check that detach gives the same result as autodiff but does not track gradients
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol0 = ot.solve(M, a, b, reg=10, grad="detach")
+
+    with pytest.raises(RuntimeError):
+        sol0.value.backward()
+
+    sol = ot.solve(M, a, b, reg=10, grad="autodiff")
+
+    assert torch.allclose(sol0.plan, sol.plan)
+    assert torch.allclose(sol0.value, sol.value)
+    assert torch.allclose(sol0.value_linear, sol.value_linear)
+    assert torch.allclose(sol0.potentials[0], sol.potentials[0])
+    assert torch.allclose(sol0.potentials[1], sol.potentials[1])
+
+
 @pytest.mark.skipif(not torch, reason="torch not installed")
 def test_solve_envelope():
     n_samples_s = 10
@@ -178,7 +263,7 @@ def test_solve_envelope():
     ga = a.grad.clone()
     gb = b.grad.clone()
 
-    # Note, gradients aer invariant to change in constant so we center them
+    # Note, gradients are invariant to change in constant so we center them
     assert torch.allclose(gM0, gM)
     assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
     assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())
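
A minimal usage sketch distilled from the example and tests above (assuming the
torch backend is available; the sample sizes and regularization value here are
illustrative only), showing how a caller selects the gradient modes introduced
by this patch:

    import torch
    import ot

    x, y = torch.rand(50, 5), torch.rand(40, 5)
    a = torch.full((50,), 1.0 / 50, requires_grad=True)
    b = torch.full((40,), 1.0 / 40, requires_grad=True)
    M = ot.dist(x, y).clone().detach().requires_grad_(True)

    # "last_step": gradients for all outputs, backpropagated through the last
    # Sinkhorn iteration only (cheaper than the full "autodiff" graph)
    res = ot.solve(M, a, b, reg=0.03, grad="last_step")
    res.value.backward()

    # "detach": no gradients are tracked, outputs are detached from the inputs
    res_detached = ot.solve(M, a, b, reg=0.03, grad="detach")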