-
Notifications
You must be signed in to change notification settings - Fork 505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Sinkhorn gradient last step #693
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #693 +/- ##
==========================================
+ Coverage 97.05% 97.06% +0.01%
==========================================
Files 98 98
Lines 19877 19955 +78
==========================================
+ Hits 19292 19370 +78
Misses 585 585
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is nice, a few comments
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very nice, a feww remaining comments
potentials = (log["log_u"], log["log_v"]) | ||
|
||
if grad == "last_step": | ||
loga = nx.log(a0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe a few comments do dicuss the sinkhorn iettaration here
@@ -412,7 +412,10 @@ def solve( | |||
potentials = (log["u"], log["v"]) | |||
|
|||
elif reg_type.lower() in ["entropy", "kl"]: | |||
if grad == "envelope": # if envelope then detach the input | |||
if grad in [ | |||
"envelope", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you also add "detach" option here and in teh documengation? seems like a nice feature that can be practical
Types of changes
I added a
grad=last_step
feature to theot.solvers.solve
function to limit the gradient computation to the last iteration of Sinkhorn instead of computing the gradient for every iteration.Motivation and context / Related issue
This change is required in the case where computing the gradient for all the Sinkhorn iterations is too heavy.
How has this been tested (if it applies)
I created a
test_solve_last_step
test to check that the gradient is different when usinggrad=last_step
andgrad=autodiff
.PR checklist