Skip to content

Commit

Permalink
update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
SoniaMaz8 committed Nov 19, 2024
1 parent 2111f83 commit 2bf5900
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def solve(
outputs (`plan, value, value_linear`) but with important memory cost.
'envelope' provides gradients only for `value` and and other outputs are
detached. This is useful for memory saving when only the value is needed. 'last_step' provides
gradients only for the last iteration of the Sinkhorn solver.
gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values.
Returns
-------
Expand Down Expand Up @@ -412,9 +412,10 @@ def solve(
potentials = (log["u"], log["v"])

elif reg_type.lower() in ["entropy", "kl"]:
if (
grad in ["envelope", "last_step"]
): # if envelope or last_step then detach the input
if grad in [
"envelope",
"last_step",
]: # if envelope or last_step then detach the input
M0, a0, b0 = M, a, b
M, a, b = nx.detach(M, a, b)

Expand Down

0 comments on commit 2bf5900

Please sign in to comment.