[MRG] Sinkhorn gradient last step #693

Merged: 17 commits, Nov 21, 2024
32 changes: 26 additions & 6 deletions ot/solvers.py
@@ -125,11 +125,12 @@
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
Type of gradient computation, either 'autodiff' or 'envelope', used only for the
Type of gradient computation, either 'autodiff', 'envelope' or 'last_step', used only for the
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
outputs (`plan, value, value_linear`) but with an important memory cost.
'envelope' provides gradients only for `value`; other outputs are
detached. This is useful for memory saving when only the value is needed.
detached. This is useful for memory saving when only the value is needed. 'last_step' provides
gradients only through the last iteration of the Sinkhorn solver.

Returns
-------
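For reference, a minimal usage sketch of the documented grad modes, assuming a POT build that includes this change and the torch backend (illustration only, not part of the diff):

import numpy as np
import torch
import ot

x, y = np.random.randn(10, 2), np.random.randn(7, 2)
M = torch.tensor(ot.dist(x, y), requires_grad=True)

# 'autodiff' differentiates through every Sinkhorn iteration (larger graph),
# 'envelope' sets gradients only for `value` at convergence,
# 'last_step' backpropagates through a single final Sinkhorn iteration.
sol = ot.solve(M, reg=1.0, grad="last_step")
sol.value.backward()
print(M.grad.shape)  # same shape as M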
@@ -281,7 +282,6 @@
linear regression. NeurIPS.

"""

# detect backend
nx = get_backend(M, a, b, c)

@@ -412,7 +412,9 @@
potentials = (log["u"], log["v"])

elif reg_type.lower() in ["entropy", "kl"]:
if grad == "envelope": # if envelope then detach the input
if (
grad == "envelope" or grad == "last_step"
): # if envelope or last_step then detach the input
M0, a0, b0 = M, a, b
M, a, b = nx.detach(M, a, b)

@@ -421,6 +423,12 @@
max_iter = 1000
if tol is None:
tol = 1e-9
if grad == "last_step":
if max_iter == 0:
raise ValueError(
    "The maximum number of iterations must be greater than 0 when using grad=last_step."
)
max_iter = max_iter - 1

plan, log = sinkhorn_log(
a,
@@ -433,6 +441,20 @@
verbose=verbose,
)

potentials = (log["log_u"], log["log_v"])

if grad == "last_step":
plan, log = sinkhorn_log(
a0,
b0,
M0,
reg=reg,
numItermax=1,
stopThr=tol,
log=True,
warmstart=potentials,
)

value_linear = nx.sum(M * plan)

if reg_type.lower() == "entropy":
@@ -442,8 +464,6 @@
plan, a[:, None] * b[None, :]
)

potentials = (log["log_u"], log["log_v"])

if grad == "envelope": # set the gradient at convergence
value = nx.set_gradients(
value,
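The new branch detaches the inputs, runs Sinkhorn for max_iter - 1 iterations, then replays a single iteration with the original (attached) tensors, warm-started from the converged potentials, so only that last step enters the autograd graph. A self-contained sketch of this pattern in plain PyTorch (scaling-form Sinkhorn rather than the log-domain solver used in the PR; sinkhorn_last_step is an illustrative name, not part of the change):

import torch

def sinkhorn_last_step(a, b, M, reg, n_iter=1000):
    # Converge on detached copies: these iterations build no autograd graph.
    Md, ad, bd = M.detach(), a.detach(), b.detach()
    Kd = torch.exp(-Md / reg)
    u = torch.ones_like(ad)
    for _ in range(n_iter - 1):
        v = bd / (Kd.T @ u)
        u = ad / (Kd @ v)
    # One extra iteration on the attached tensors, warm-started with the
    # converged scaling u: only this step is differentiated.
    K = torch.exp(-M / reg)
    v = b / (K.T @ u)
    u = a / (K @ v)
    return u[:, None] * K * v[None, :]

a = torch.full((5,), 0.2, requires_grad=True)
b = torch.full((4,), 0.25, requires_grad=True)
M = torch.rand(5, 4, requires_grad=True)
plan = sinkhorn_last_step(a, b, M, reg=1.0)
(M * plan).sum().backward()  # gradients flow only through the last iteration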
44 changes: 43 additions & 1 deletion test/test_solvers.py
@@ -143,6 +143,48 @@ def test_solve(nx):
sol0 = ot.solve(M, reg=1, reg_type="cryptic divergence")


@pytest.mark.skipif(not torch, reason="torch no installed")
def test_solve_last_step():
n_samples_s = 10
n_samples_t = 7
n_features = 2
rng = np.random.RandomState(0)

x = rng.randn(n_samples_s, n_features)
y = rng.randn(n_samples_t, n_features)
a = ot.utils.unif(n_samples_s)
b = ot.utils.unif(n_samples_t)
M = ot.dist(x, y)

# Check that for multiple iterations, autodiff and last_step give different gradients
a = torch.tensor(a, requires_grad=True)
b = torch.tensor(b, requires_grad=True)
M = torch.tensor(M, requires_grad=True)

sol0 = ot.solve(M, a, b, reg=10, grad="autodiff")
sol0.value.backward()

gM0 = M.grad.clone()
ga0 = a.grad.clone()
gb0 = b.grad.clone()

a = torch.tensor(a, requires_grad=True)
b = torch.tensor(b, requires_grad=True)
M = torch.tensor(M, requires_grad=True)

sol = ot.solve(M, a, b, reg=10, grad="last_step")
sol.value.backward()

gM = M.grad.clone()
ga = a.grad.clone()
gb = b.grad.clone()

# Note, gradients are invariant to change in constant so we center them
assert not torch.allclose(gM0, gM)
assert not torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
assert not torch.allclose(gb0 - gb0.mean(), gb - gb.mean())


@pytest.mark.skipif(not torch, reason="torch no installed")
def test_solve_envelope():
n_samples_s = 10
@@ -178,7 +220,7 @@ def test_solve_envelope():
ga = a.grad.clone()
gb = b.grad.clone()

# Note, gradients aer invariant to change in constant so we center them
# Note, gradients are invariant to change in constant so we center them
assert torch.allclose(gM0, gM)
assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())