From 2bf590072905e64bce9f04a61eb547a881e98bec Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:10:40 +0100 Subject: [PATCH] update doc --- ot/solvers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 1800ca6c2..8e64a65c4 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -130,7 +130,7 @@ def solve( outputs (`plan, value, value_linear`) but with important memory cost. 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. 'last_step' provides - gradients only for the last iteration of the Sinkhorn solver. + gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values. Returns ------- @@ -412,9 +412,10 @@ def solve( potentials = (log["u"], log["v"]) elif reg_type.lower() in ["entropy", "kl"]: - if ( - grad in ["envelope", "last_step"] - ): # if envelope or last_step then detach the input + if grad in [ + "envelope", + "last_step", + ]: # if envelope or last_step then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b)