From adab5951f4aea605f853f64eae8a1b3086be37c0 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Thu, 7 Nov 2024 15:48:27 +0100 Subject: [PATCH 01/16] change solver --- ot/solvers.py | 32 +++++++++++++++---- test/test_solvers.py | 74 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 7 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index ec56d1330..1133a76ce 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -124,11 +124,12 @@ def solve( verbose : bool, optional Print information in the solver, by default False grad : str, optional - Type of gradient computation, either or 'autodiff' or 'envelope' used only for + Type of gradient computation, either or 'autodiff', 'envelope' or 'last_step' used only for Sinkhorn solver. By default 'autodiff' provides gradients wrt all outputs (`plan, value, value_linear`) but with important memory cost. 'envelope' provides gradients only for `value` and and other outputs are - detached. This is useful for memory saving when only the value is needed. + detached. This is useful for memory saving when only the value is needed. 'last_step' provides + gradients only for the last iteration of the Sinkhorn solver. Returns ------- @@ -280,7 +281,6 @@ def solve( linear regression. NeurIPS. """ - # detect backend nx = get_backend(M, a, b, c) @@ -411,7 +411,9 @@ def solve( potentials = (log["u"], log["v"]) elif reg_type.lower() in ["entropy", "kl"]: - if grad == "envelope": # if envelope then detach the input + if ( + grad == "envelope" or grad == "last_step" + ): # if envelope or last_step then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) @@ -420,6 +422,12 @@ def solve( max_iter = 1000 if tol is None: tol = 1e-9 + if grad == "last_step": + if max_iter == 0: + raise ValueError( + "The maximum number of iterations must be greater than 0 when using grad=last_step." 
+ ) + max_iter = max_iter - 1 plan, log = sinkhorn_log( a, @@ -432,6 +440,20 @@ def solve( verbose=verbose, ) + potentials = (log["log_u"], log["log_v"]) + + if grad == "last_step": + plan, log = sinkhorn_log( + a0, + b0, + M0, + reg=reg, + numItermax=1, + stopThr=tol, + log=True, + warmstart=potentials, + ) + value_linear = nx.sum(M * plan) if reg_type.lower() == "entropy": @@ -441,8 +463,6 @@ def solve( plan, a[:, None] * b[None, :] ) - potentials = (log["log_u"], log["log_v"]) - if grad == "envelope": # set the gradient at convergence value = nx.set_gradients( value, diff --git a/test/test_solvers.py b/test/test_solvers.py index 82a402df1..ec6a6c832 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -143,6 +143,78 @@ def test_solve(nx): sol0 = ot.solve(M, reg=1, reg_type="cryptic divergence") +@pytest.mark.skipif(not torch, reason="torch no installed") +def test_solve_last_step(): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + M = ot.dist(x, y) + + # Check that for one iteration, autodiff and last_step give the same gradients + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol0 = ot.solve(M, a, b, max_iter=1, grad="autodiff") + sol0.value.backward() + + gM0 = M.grad.clone() + ga0 = a.grad.clone() + gb0 = b.grad.clone() + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol = ot.solve(M, a, b, max_iter=1, grad="last_step") + sol.value.backward() + + gM = M.grad.clone() + ga = a.grad.clone() + gb = b.grad.clone() + + # Note, gradients are invariant to change in constant so we center them + assert torch.allclose(gM0, gM) + assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) + assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) + + # Check that for multiple iterations, autodiff and last_step give different gradients + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol0 = ot.solve(M, a, b, reg=1, grad="autodiff") + sol0.value.backward() + + gM0 = M.grad.clone() + ga0 = a.grad.clone() + gb0 = b.grad.clone() + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol = ot.solve(M, a, b, reg=1, grad="last_ste") + sol.value.backward() + + print(sol, sol0) + + gM = M.grad.clone() + ga = a.grad.clone() + gb = b.grad.clone() + + # Note, gradients are invariant to change in constant so we center them + assert not torch.allclose(gM0, gM) + assert not torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) + assert not torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) + + @pytest.mark.skipif(not torch, reason="torch no installed") def test_solve_envelope(): n_samples_s = 10 @@ -178,7 +250,7 @@ def test_solve_envelope(): ga = a.grad.clone() gb = b.grad.clone() - # Note, gradients aer invariant to change in constant so we center them + # Note, gradients are invariant to change in constant so we center them assert torch.allclose(gM0, gM) assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) From d31e759d28cc84b5b61ff6a71fa23941cfd797fb Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: 
Tue, 19 Nov 2024 09:30:27 +0100 Subject: [PATCH 02/16] test --- test/test_solvers.py | 36 ++++-------------------------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index c93727ae0..782a1520d 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -156,40 +156,12 @@ def test_solve_last_step(): b = ot.utils.unif(n_samples_t) M = ot.dist(x, y) - # Check that for one iteration, autodiff and last_step give the same gradients - a = torch.tensor(a, requires_grad=True) - b = torch.tensor(b, requires_grad=True) - M = torch.tensor(M, requires_grad=True) - - sol0 = ot.solve(M, a, b, max_iter=1, grad="autodiff") - sol0.value.backward() - - gM0 = M.grad.clone() - ga0 = a.grad.clone() - gb0 = b.grad.clone() - - a = torch.tensor(a, requires_grad=True) - b = torch.tensor(b, requires_grad=True) - M = torch.tensor(M, requires_grad=True) - - sol = ot.solve(M, a, b, max_iter=1, grad="last_step") - sol.value.backward() - - gM = M.grad.clone() - ga = a.grad.clone() - gb = b.grad.clone() - - # Note, gradients are invariant to change in constant so we center them - assert torch.allclose(gM0, gM) - assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) - assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) - # Check that for multiple iterations, autodiff and last_step give different gradients a = torch.tensor(a, requires_grad=True) b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) - sol0 = ot.solve(M, a, b, reg=1, grad="autodiff") + sol0 = ot.solve(M, a, b, reg=10, grad="autodiff") sol0.value.backward() gM0 = M.grad.clone() @@ -200,15 +172,15 @@ def test_solve_last_step(): b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) - sol = ot.solve(M, a, b, reg=1, grad="last_ste") + sol = ot.solve(M, a, b, reg=10, grad="last_step") sol.value.backward() - print(sol, sol0) - gM = M.grad.clone() ga = a.grad.clone() gb = b.grad.clone() + print(gM, gM0) + # Note, gradients are invariant to change in constant so we center them assert not torch.allclose(gM0, gM) assert not torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) From e822e99faa3132b00b5e663165de260d83b99171 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Tue, 19 Nov 2024 09:31:26 +0100 Subject: [PATCH 03/16] update test --- test/test_solvers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index 782a1520d..05818d9b9 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -179,8 +179,6 @@ def test_solve_last_step(): ga = a.grad.clone() gb = b.grad.clone() - print(gM, gM0) - # Note, gradients are invariant to change in constant so we center them assert not torch.allclose(gM0, gM) assert not torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) From 2111f83236bcb53f7937928f2db556bed6aeeaa1 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:02:44 +0100 Subject: [PATCH 04/16] Update ot/solvers.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: RĂ©mi Flamary --- ot/solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/solvers.py b/ot/solvers.py index 515e2a4be..1800ca6c2 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -413,7 +413,7 @@ def solve( elif reg_type.lower() in ["entropy", "kl"]: if ( - grad == "envelope" or grad == "last_step" + grad in ["envelope", "last_step"] ): # if envelope or 
last_step then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) From 2bf590072905e64bce9f04a61eb547a881e98bec Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:10:40 +0100 Subject: [PATCH 05/16] update doc --- ot/solvers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 1800ca6c2..8e64a65c4 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -130,7 +130,7 @@ def solve( outputs (`plan, value, value_linear`) but with important memory cost. 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. 'last_step' provides - gradients only for the last iteration of the Sinkhorn solver. + gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values. Returns ------- @@ -412,9 +412,10 @@ def solve( potentials = (log["u"], log["v"]) elif reg_type.lower() in ["entropy", "kl"]: - if ( - grad in ["envelope", "last_step"] - ): # if envelope or last_step then detach the input + if grad in [ + "envelope", + "last_step", + ]: # if envelope or last_step then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) From cabb1045d836cc89e3927931aa78402c10ba166a Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:12:59 +0100 Subject: [PATCH 06/16] add test for max_iter --- test/test_solvers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_solvers.py b/test/test_solvers.py index 05818d9b9..88970d9eb 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -184,6 +184,9 @@ def test_solve_last_step(): assert not torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) assert not torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) + with pytest.raises(ValueError): + ot.solve(M, a, b, grad="last_step", max_iter=0, reg=10) + @pytest.mark.skipif(not torch, reason="torch no installed") def test_solve_envelope(): From 405b545057fea9494001fcf4af57d189dccc4a5c Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 10:20:40 +0100 Subject: [PATCH 07/16] fix bug on gradients --- ot/solvers.py | 21 +++++++++++---------- test/test_solvers.py | 15 ++++++++++++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 8e64a65c4..d0b584c0f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -445,16 +445,17 @@ def solve( potentials = (log["log_u"], log["log_v"]) if grad == "last_step": - plan, log = sinkhorn_log( - a0, - b0, - M0, - reg=reg, - numItermax=1, - stopThr=tol, - log=True, - warmstart=potentials, - ) + loga = nx.log(a0) + logb = nx.log(b0) + v = logb - nx.logsumexp(-M0 / reg + potentials[0][:, None], 0) + u = loga - nx.logsumexp(-M0 / reg + potentials[1][None, :], 1) + plan = nx.exp(-M0 / reg + u[:, None] + v[None, :]) + potentials = (u, v) + log["niter"] = max_iter + 1 + log["log_u"] = u + log["log_v"] = v + log["u"] = nx.exp(u) + log["v"] = nx.exp(v) value_linear = nx.sum(M * plan) diff --git a/test/test_solvers.py b/test/test_solvers.py index 88970d9eb..3284c1a0e 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -179,10 +179,19 @@ def test_solve_last_step(): ga = a.grad.clone() gb = b.grad.clone() + cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) + # Note, gradients are invariant to change in constant so we center 
them - assert not torch.allclose(gM0, gM) - assert not torch.allclose(ga0 - ga0.mean(), ga - ga.mean()) - assert not torch.allclose(gb0 - gb0.mean(), gb - gb.mean()) + tolerance = 0.96 + assert cos(gM0.flatten(), gM.flatten()) > tolerance + assert cos(ga0 - ga0.mean(), ga - ga.mean()) > tolerance + assert cos(gb0 - gb0.mean(), gb - gb.mean()) > tolerance + + assert torch.allclose(sol0.plan, sol.plan) + assert torch.allclose(sol0.value, sol.value) + assert torch.allclose(sol0.value_linear, sol.value_linear) + assert torch.allclose(sol0.potentials[0], sol.potentials[0]) + assert torch.allclose(sol0.potentials[1], sol.potentials[1]) with pytest.raises(ValueError): ot.solve(M, a, b, grad="last_step", max_iter=0, reg=10) From 73e2e1949ed6cf5aa9f3201ec08c9f79f67f2163 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 10:27:12 +0100 Subject: [PATCH 08/16] update RELEASES.md --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 1caaa04b4..e74b44507 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,7 @@ #### New features - Implement CG solvers for partial FGW (PR #687) +- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) From c334a966b4796564e076831e27c92ae4d24cdfa5 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 10:30:35 +0100 Subject: [PATCH 09/16] update comment --- test/test_solvers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index 3284c1a0e..a27f89f2c 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -156,7 +156,7 @@ def test_solve_last_step(): b = ot.utils.unif(n_samples_t) M = ot.dist(x, y) - # Check that for multiple iterations, autodiff and last_step give different gradients + # Check that last_step and autodiff give the same result and similar gradients a = torch.tensor(a, requires_grad=True) b = torch.tensor(b, requires_grad=True) M = torch.tensor(M, requires_grad=True) @@ -179,9 +179,8 @@ def test_solve_last_step(): ga = a.grad.clone() gb = b.grad.clone() - cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) - # Note, gradients are invariant to change in constant so we center them + cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) tolerance = 0.96 assert cos(gM0.flatten(), gM.flatten()) > tolerance assert cos(ga0 - ga0.mean(), ga - ga.mean()) > tolerance From 326bcb531c258d258a8de89a84a1c1484d653b35 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:22:01 +0100 Subject: [PATCH 10/16] add detach and comment --- ot/solvers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ot/solvers.py b/ot/solvers.py index d0b584c0f..80f366354 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -131,6 +131,7 @@ def solve( 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. 'last_step' provides gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values. + 'detach' does not compute the gradients for the Sinkhorn solver. 
Returns ------- @@ -415,7 +416,8 @@ def solve( if grad in [ "envelope", "last_step", - ]: # if envelope or last_step then detach the input + "detach", + ]: # if envelope, last_step or detach then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b) @@ -444,6 +446,7 @@ def solve( potentials = (log["log_u"], log["log_v"]) + # if last_step, compute the last step of the Sinkhorn algorithm with the non-detached inputs if grad == "last_step": loga = nx.log(a0) logb = nx.log(b0) From 2952331ce866bc0250d4182a2502aa4cbd2d8aea Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:22:29 +0100 Subject: [PATCH 11/16] add example --- examples/plot_Sinkhorn_gradients.py | 87 +++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 examples/plot_Sinkhorn_gradients.py diff --git a/examples/plot_Sinkhorn_gradients.py b/examples/plot_Sinkhorn_gradients.py new file mode 100644 index 000000000..0da45d646 --- /dev/null +++ b/examples/plot_Sinkhorn_gradients.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +""" +================================================ +Different gradient computations for regularized optimal transport +================================================ + +This example illustrates the differences in terms of computation time between the gradient options for the Sinkhorn solver. + +""" + +# Author: Sonia Mazelet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import make_1D_gauss as gauss +from ot.backend import torch + + +############################################################################## +# Time comparison of the Sinkhorn solver for different gradient options +# ------------- + + +# %% parameters + +n = 100 # nb bins +n_trials = 500 +times_autodiff = torch.zeros(n_trials) +times_envelope = torch.zeros(n_trials) +times_last_step = torch.zeros(n_trials) + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Time required for the Sinkhorn solver and gradient computations, for different gradient options over multiple Gaussian distributions +for i in range(n_trials): + # Gaussian distributions with random parameters + ma = np.random.randint(10, 40, 2) + sa = np.random.randint(5, 10, 2) + mb = np.random.randint(10, 40) + sb = np.random.randint(5, 10) + + a = 0.6 * gauss(n, m=ma[0], s=sa[0]) + 0.4 * gauss( + n, m=ma[1], s=sa[1] + ) # m= mean, s= std + b = gauss(n, m=mb, s=sb) + + # loss matrix + M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) + M /= M.max() + + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + # autodiff provides the gradient for all the outputs (plan, value, value_linear) + ot.tic() + res_autodiff = ot.solve(M, a, b, reg=10, grad="autodiff") + res_autodiff.value.backward() + times_autodiff[i] = ot.toq() + + # envelope provides the gradient for value + ot.tic() + res_envelope = ot.solve(M, a, b, reg=10, grad="envelope") + res_envelope.value.backward() + times_envelope[i] = ot.toq() + + # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm + ot.tic() + res_last_step = ot.solve(M, a, b, reg=10, grad="last_step") + res_last_step.value.backward() + times_last_step[i] = ot.toq() + +pl.figure(1, figsize=(4, 3)) +pl.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) +pl.boxplot( + ([times_autodiff, times_envelope, 
times_last_step]), + tick_labels=["autodiff", "envelope", "last_step"], + showfliers=False, +) +pl.ylabel("Time (s)") +pl.show() From fba3e2ca2a25ffbecb2a44f488a3f7a53297344b Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:23:00 +0100 Subject: [PATCH 12/16] add test for detach --- test/test_solvers.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/test_solvers.py b/test/test_solvers.py index a27f89f2c..85852aca6 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -196,6 +196,38 @@ def test_solve_last_step(): ot.solve(M, a, b, grad="last_step", max_iter=0, reg=10) +@pytest.mark.skipif(not torch, reason="torch no installed") +def test_solve_detach(): + n_samples_s = 10 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + M = ot.dist(x, y) + + # Check that last_step and autodiff give the same result and similar gradients + a = torch.tensor(a, requires_grad=True) + b = torch.tensor(b, requires_grad=True) + M = torch.tensor(M, requires_grad=True) + + sol0 = ot.solve(M, a, b, reg=10, grad="detach") + + with pytest.raises(RuntimeError): + sol0.value.backward() + + sol = ot.solve(M, a, b, reg=10, grad="autodiff") + + assert torch.allclose(sol0.plan, sol.plan) + assert torch.allclose(sol0.value, sol.value) + assert torch.allclose(sol0.value_linear, sol.value_linear) + assert torch.allclose(sol0.potentials[0], sol.potentials[0]) + assert torch.allclose(sol0.potentials[1], sol.potentials[1]) + + @pytest.mark.skipif(not torch, reason="torch no installed") def test_solve_envelope(): n_samples_s = 10 From d341e67822c112772e1aec7f9d42816d58579661 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Wed, 20 Nov 2024 17:30:39 +0100 Subject: [PATCH 13/16] fix example --- examples/plot_Sinkhorn_gradients.py | 47 +++++++++++++++-------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/examples/plot_Sinkhorn_gradients.py b/examples/plot_Sinkhorn_gradients.py index 0da45d646..39c547ae2 100644 --- a/examples/plot_Sinkhorn_gradients.py +++ b/examples/plot_Sinkhorn_gradients.py @@ -1,3 +1,4 @@ +# %% # -*- coding: utf-8 -*- """ ================================================ @@ -28,55 +29,55 @@ # %% parameters -n = 100 # nb bins -n_trials = 500 +n_trials = 30 times_autodiff = torch.zeros(n_trials) times_envelope = torch.zeros(n_trials) times_last_step = torch.zeros(n_trials) -# bin positions -x = np.arange(n, dtype=np.float64) +n_samples_s = 300 +n_samples_t = 300 +n_features = 5 +reg = 0.03 # Time required for the Sinkhorn solver and gradient computations, for different gradient options over multiple Gaussian distributions for i in range(n_trials): - # Gaussian distributions with random parameters - ma = np.random.randint(10, 40, 2) - sa = np.random.randint(5, 10, 2) - mb = np.random.randint(10, 40) - sb = np.random.randint(5, 10) - - a = 0.6 * gauss(n, m=ma[0], s=sa[0]) + 0.4 * gauss( - n, m=ma[1], s=sa[1] - ) # m= mean, s= std - b = gauss(n, m=mb, s=sb) - - # loss matrix - M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) - M /= M.max() + x = torch.rand((n_samples_s, n_features)) + y = torch.rand((n_samples_t, n_features)) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + M = ot.dist(x, y) a = torch.tensor(a, requires_grad=True) b = 
torch.tensor(b, requires_grad=True) - M = torch.tensor(M, requires_grad=True) + M = M.clone().detach().requires_grad_(True) # autodiff provides the gradient for all the outputs (plan, value, value_linear) ot.tic() - res_autodiff = ot.solve(M, a, b, reg=10, grad="autodiff") + res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff") res_autodiff.value.backward() times_autodiff[i] = ot.toq() + a = a.clone().detach().requires_grad_(True) + b = b.clone().detach().requires_grad_(True) + M = M.clone().detach().requires_grad_(True) + # envelope provides the gradient for value ot.tic() - res_envelope = ot.solve(M, a, b, reg=10, grad="envelope") + res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope") res_envelope.value.backward() times_envelope[i] = ot.toq() + a = a.clone().detach().requires_grad_(True) + b = b.clone().detach().requires_grad_(True) + M = M.clone().detach().requires_grad_(True) + # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm ot.tic() - res_last_step = ot.solve(M, a, b, reg=10, grad="last_step") + res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step") res_last_step.value.backward() times_last_step[i] = ot.toq() -pl.figure(1, figsize=(4, 3)) +pl.figure(1, figsize=(5, 3)) pl.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) pl.boxplot( ([times_autodiff, times_envelope, times_last_step]), From 048d468fbfaa2a08101597d23aa73d44f192bbee Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:33:09 +0100 Subject: [PATCH 14/16] delete unused importations in example --- examples/plot_Sinkhorn_gradients.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/plot_Sinkhorn_gradients.py b/examples/plot_Sinkhorn_gradients.py index 39c547ae2..457113a73 100644 --- a/examples/plot_Sinkhorn_gradients.py +++ b/examples/plot_Sinkhorn_gradients.py @@ -15,10 +15,8 @@ # sphinx_gallery_thumbnail_number = 4 -import numpy as np import matplotlib.pylab as pl import ot -from ot.datasets import make_1D_gauss as gauss from ot.backend import torch @@ -78,7 +76,7 @@ times_last_step[i] = ot.toq() pl.figure(1, figsize=(5, 3)) -pl.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) +pl.ticklabel_format(axis="y", style="sci", scilimits=(0, 0)) pl.boxplot( ([times_autodiff, times_envelope, times_last_step]), tick_labels=["autodiff", "envelope", "last_step"], From bf0a65e2151a9cc6415eb1b24aded9e76a6a75a1 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:39:51 +0100 Subject: [PATCH 15/16] move example to backend --- examples/{ => backends}/plot_Sinkhorn_gradients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename examples/{ => backends}/plot_Sinkhorn_gradients.py (98%) diff --git a/examples/plot_Sinkhorn_gradients.py b/examples/backends/plot_Sinkhorn_gradients.py similarity index 98% rename from examples/plot_Sinkhorn_gradients.py rename to examples/backends/plot_Sinkhorn_gradients.py index 457113a73..e5e8f1909 100644 --- a/examples/plot_Sinkhorn_gradients.py +++ b/examples/backends/plot_Sinkhorn_gradients.py @@ -13,7 +13,7 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 1 import matplotlib.pylab as pl import ot From 0e985b27f0b0f6148cf00108fa5c3ec539049871 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:34:28 +0100 Subject: [PATCH 16/16] 
reduce n_trials for example --- examples/backends/plot_Sinkhorn_gradients.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/backends/plot_Sinkhorn_gradients.py b/examples/backends/plot_Sinkhorn_gradients.py index e5e8f1909..229a952b1 100644 --- a/examples/backends/plot_Sinkhorn_gradients.py +++ b/examples/backends/plot_Sinkhorn_gradients.py @@ -1,4 +1,3 @@ -# %% # -*- coding: utf-8 -*- """ ================================================ @@ -27,7 +26,7 @@ # %% parameters -n_trials = 30 +n_trials = 10 times_autodiff = torch.zeros(n_trials) times_envelope = torch.zeros(n_trials) times_last_step = torch.zeros(n_trials)
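
Usage sketch (editor's addition, not part of the patch series above): a minimal example of how the `grad` options introduced by these patches ('autodiff', 'envelope', 'last_step', 'detach') are called through `ot.solve`, assuming a POT build that includes PR #693 and an installed torch backend. The problem sizes and `reg` value simply mirror the tests shown in the patches.

    import numpy as np
    import torch
    import ot

    # Small random point clouds with uniform weights (same setup as the tests above).
    rng = np.random.RandomState(0)
    x, y = rng.randn(10, 2), rng.randn(7, 2)
    a = torch.tensor(ot.utils.unif(10), requires_grad=True)
    b = torch.tensor(ot.utils.unif(7), requires_grad=True)
    M = torch.tensor(ot.dist(x, y), requires_grad=True)

    # 'autodiff' (default): gradients w.r.t. all outputs (plan, value, value_linear),
    # obtained by backpropagating through every Sinkhorn iteration.
    sol = ot.solve(M, a, b, reg=10, grad="autodiff")
    sol.value.backward()

    # 'last_step': the Sinkhorn loop runs on detached inputs and only the final
    # log-domain update is differentiated, so plan and values still carry gradients.
    a2 = torch.tensor(ot.utils.unif(10), requires_grad=True)
    b2 = torch.tensor(ot.utils.unif(7), requires_grad=True)
    M2 = torch.tensor(ot.dist(x, y), requires_grad=True)
    sol2 = ot.solve(M2, a2, b2, reg=10, grad="last_step")
    sol2.value.backward()

    # 'envelope': only `value` is differentiable (gradient set at convergence),
    # the other outputs are detached.
    sol3 = ot.solve(M2, a2, b2, reg=10, grad="envelope")

    # 'detach': nothing is tracked; calling sol4.value.backward() would raise a RuntimeError.
    sol4 = ot.solve(M2, a2, b2, reg=10, grad="detach")

The example file added in patches 11 through 16 (examples/backends/plot_Sinkhorn_gradients.py) compares the runtime of the three differentiable options in the same way.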