[MRG] Sinkhorn gradient last step #693
Merged
Changes from all commits — 17 commits
adab595  change solver (SoniaMaz8)
6a3eaaa  Merge branch 'PythonOT:master' into Sinkhorn_last_step (SoniaMaz8)
d31e759  test (SoniaMaz8)
e822e99  update test (SoniaMaz8)
2111f83  Update ot/solvers.py (SoniaMaz8)
2bf5900  update doc (SoniaMaz8)
cabb104  add test for max_iter (SoniaMaz8)
405b545  fix bug on gradients (SoniaMaz8)
73e2e19  update RELEASES.md (SoniaMaz8)
c334a96  update comment (SoniaMaz8)
326bcb5  add detach and comment (SoniaMaz8)
2952331  add example (SoniaMaz8)
fba3e2c  add test for detach (SoniaMaz8)
d341e67  fix example (SoniaMaz8)
048d468  delete unused importations in example (SoniaMaz8)
bf0a65e  move example to backend (SoniaMaz8)
0e985b2  reduce n_trials for example (SoniaMaz8)
New example file:
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
"""
==================================================================
Different gradient computations for regularized optimal transport
==================================================================

This example illustrates the differences in terms of computation time between the gradient options for the Sinkhorn solver.

"""

# Author: Sonia Mazelet <sonia.mazelet@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import matplotlib.pylab as pl
import ot
from ot.backend import torch


##############################################################################
# Time comparison of the Sinkhorn solver for different gradient options
# ---------------------------------------------------------------------


# %% parameters

n_trials = 10
times_autodiff = torch.zeros(n_trials)
times_envelope = torch.zeros(n_trials)
times_last_step = torch.zeros(n_trials)

n_samples_s = 300
n_samples_t = 300
n_features = 5
reg = 0.03

# Time the Sinkhorn solver and its gradient computation for each gradient option,
# over several random (uniform) point clouds
for i in range(n_trials):
    x = torch.rand((n_samples_s, n_features))
    y = torch.rand((n_samples_t, n_features))
    a = ot.utils.unif(n_samples_s)
    b = ot.utils.unif(n_samples_t)
    M = ot.dist(x, y)

    a = torch.tensor(a, requires_grad=True)
    b = torch.tensor(b, requires_grad=True)
    M = M.clone().detach().requires_grad_(True)

    # autodiff provides the gradient for all the outputs (plan, value, value_linear)
    ot.tic()
    res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff")
    res_autodiff.value.backward()
    times_autodiff[i] = ot.toq()

    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    M = M.clone().detach().requires_grad_(True)

    # envelope provides the gradient for the value only
    ot.tic()
    res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope")
    res_envelope.value.backward()
    times_envelope[i] = ot.toq()

    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    M = M.clone().detach().requires_grad_(True)

    # last_step provides the gradient for all the outputs, but only through the last iteration of the Sinkhorn algorithm
    ot.tic()
    res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step")
    res_last_step.value.backward()
    times_last_step[i] = ot.toq()

pl.figure(1, figsize=(5, 3))
pl.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
pl.boxplot(
    ([times_autodiff, times_envelope, times_last_step]),
    tick_labels=["autodiff", "envelope", "last_step"],
    showfliers=False,
)
pl.ylabel("Time (s)")
pl.show()
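A minimal usage sketch (not part of the PR's files) of the `last_step` option introduced by this PR, assuming the torch backend; sizes and the regularization value are illustrative only:

import torch
import ot

x = torch.rand((30, 5))
y = torch.rand((40, 5))
M = ot.dist(x, y).clone().detach().requires_grad_(True)

res = ot.solve(M, reg=0.03, grad="last_step")
res.plan.sum().backward()   # gradients flow through the final Sinkhorn iteration only
print(M.grad.shape)         # torch.Size([30, 40])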
ot/solvers.py
@@ -125,11 +125,13 @@ def solve(
    verbose : bool, optional
        Print information in the solver, by default False
    grad : str, optional
        Type of gradient computation, either 'autodiff', 'envelope' or 'last_step', used only for the
        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
        outputs (`plan, value, value_linear`) but with important memory cost.
        'envelope' provides gradients only for `value` and other outputs are
        detached. This is useful for memory saving when only the value is needed. 'last_step' provides
        gradients only through the last iteration of the Sinkhorn solver, but for both the OT plan and the objective value.
        'detach' does not compute the gradients for the Sinkhorn solver.

    Returns
    -------
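For illustration, a hedged usage sketch of how these documented options differ from the caller's side (tensor sizes, `reg`, and the random cost matrix are arbitrary; the behavior follows the docstring above):

import torch
import ot

M = torch.rand(20, 30, requires_grad=True)
res = ot.solve(M, reg=0.1, grad="envelope")
res.value.backward()            # works: gradient of the value via the envelope theorem
print(res.plan.requires_grad)   # False: the plan is detached with "envelope"

M = torch.rand(20, 30, requires_grad=True)
res = ot.solve(M, reg=0.1, grad="detach")
print(res.value.requires_grad)  # False: nothing is tracked with "detach"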
@@ -281,7 +283,6 @@ def solve(
        linear regression. NeurIPS.

    """

    # detect backend
    nx = get_backend(M, a, b, c)
@@ -412,7 +413,11 @@ def solve(
            potentials = (log["u"], log["v"])

        elif reg_type.lower() in ["entropy", "kl"]:
            if grad in [
                "envelope",
                "last_step",
                "detach",
            ]:  # if envelope, last_step or detach then detach the input
                M0, a0, b0 = M, a, b
                M, a, b = nx.detach(M, a, b)
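A toy illustration (an assumption about the intent, not the library's internal code) of what detaching buys here: the loop runs on history-free copies while the original tensors stay attached for later use.

import torch

M = torch.rand(5, 5, requires_grad=True)
M0, M_loop = M, M.detach()    # M0 stays attached, M_loop carries no autograd history
print(M_loop.requires_grad)   # False: Sinkhorn iterations on M_loop build no graph
print(M0.requires_grad)       # True: gradients can still be attached at the end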
@@ -421,6 +426,12 @@ def solve(
                max_iter = 1000
            if tol is None:
                tol = 1e-9
            if grad == "last_step":
                if max_iter == 0:
                    raise ValueError(
                        "The maximum number of iterations must be greater than 0 when using grad=last_step."
                    )
                max_iter = max_iter - 1

            plan, log = sinkhorn_log(
                a,
@@ -433,6 +444,22 @@ def solve(
                verbose=verbose,
            )

            potentials = (log["log_u"], log["log_v"])

            # if last_step, compute the last step of the Sinkhorn algorithm with the non-detached inputs
            if grad == "last_step":
                loga = nx.log(a0)
                logb = nx.log(b0)
                v = logb - nx.logsumexp(-M0 / reg + potentials[0][:, None], 0)
                u = loga - nx.logsumexp(-M0 / reg + potentials[1][None, :], 1)
                plan = nx.exp(-M0 / reg + u[:, None] + v[None, :])
                potentials = (u, v)
                log["niter"] = max_iter + 1
                log["log_u"] = u
                log["log_v"] = v
                log["u"] = nx.exp(u)
                log["v"] = nx.exp(v)

            value_linear = nx.sum(M * plan)

            if reg_type.lower() == "entropy":

Review comment on the block above: maybe add a few comments here to discuss the Sinkhorn iteration.
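To make the `last_step` mechanism of the block above concrete, here is a hedged, self-contained re-implementation of the same idea in plain PyTorch (simplified notation, not the code added by this PR): run `max_iter - 1` log-domain Sinkhorn iterations without tracking gradients, then replay a single update on the attached inputs so that only this step is differentiated.

import torch

def sinkhorn_plan_last_step(a, b, M, reg, max_iter=100):
    # All but the last iteration run without building an autograd graph.
    with torch.no_grad():
        u = torch.zeros_like(a)
        v = torch.zeros_like(b)
        K = -M / reg
        for _ in range(max_iter - 1):
            u = torch.log(a) - torch.logsumexp(K + v[None, :], dim=1)
            v = torch.log(b) - torch.logsumexp(K + u[:, None], dim=0)
    # Replay one update on the attached tensors: only this step is differentiated.
    K = -M / reg
    v = torch.log(b) - torch.logsumexp(K + u[:, None], dim=0)
    u = torch.log(a) - torch.logsumexp(K + v[None, :], dim=1)
    return torch.exp(K + u[:, None] + v[None, :])

a = torch.full((30,), 1 / 30, requires_grad=True)
b = torch.full((40,), 1 / 40, requires_grad=True)
M = torch.rand(30, 40, requires_grad=True)
P = sinkhorn_plan_last_step(a, b, M, reg=0.05)
P.sum().backward()   # gradients flow only through the replayed final step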
@@ -442,8 +469,6 @@ def solve(
                    plan, a[:, None] * b[None, :]
                )

            if grad == "envelope":  # set the gradient at convergence
                value = nx.set_gradients(
                    value,
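As a hedged sanity check (not part of the PR) of what the envelope-theorem shortcut above encodes: for entropic OT, the gradient of the regularized value with respect to the cost matrix is the optimal plan, so an `autodiff` gradient of `value` should be close to the plan once Sinkhorn has converged.

import torch
import ot

M = torch.rand(20, 30, requires_grad=True)
res = ot.solve(M, reg=0.1, grad="autodiff")
res.value.backward()
# The difference should be small if the solver converged (exact equality is not expected).
print((M.grad - res.plan.detach()).abs().max())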
Review comment: could you also add the "detach" option here and in the documentation? Seems like a nice feature that can be practical.