diff --git a/ot/solvers.py b/ot/solvers.py index 1800ca6c2..8e64a65c4 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -130,7 +130,7 @@ def solve( outputs (`plan, value, value_linear`) but with important memory cost. 'envelope' provides gradients only for `value` and and other outputs are detached. This is useful for memory saving when only the value is needed. 'last_step' provides - gradients only for the last iteration of the Sinkhorn solver. + gradients only for the last iteration of the Sinkhorn solver, but provides gradient for both the OT plan and the objective values. Returns ------- @@ -412,9 +412,10 @@ def solve( potentials = (log["u"], log["v"]) elif reg_type.lower() in ["entropy", "kl"]: - if ( - grad in ["envelope", "last_step"] - ): # if envelope or last_step then detach the input + if grad in [ + "envelope", + "last_step", + ]: # if envelope or last_step then detach the input M0, a0, b0 = M, a, b M, a, b = nx.detach(M, a, b)