-
Notifications
You must be signed in to change notification settings - Fork 505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Nearest Brenier Potentials #526
Merged
Merged
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
f2c65a8
added functions to a new mapping module
eloitanguy 284c004
simplify ssnb function structure
eloitanguy a2d9ae8
RELEASES.md conflict fix
eloitanguy cd7be18
SSNB example
eloitanguy c82e0e6
removed numpy saves from example for prod
eloitanguy 79a99ef
tests apart from the import exception catch
eloitanguy 801c280
tests apart from the import exception catch
eloitanguy ff46975
da class and tests
eloitanguy 60944f5
guessed PR number
eloitanguy d0d42be
Merge remote-tracking branch 'origin/master' into contrib_ssnb
eloitanguy 7bc3213
removed unused import
eloitanguy 55f0e09
PEP8 tab errors fix
eloitanguy 9dfd82e
skip ssnb test if no cvxpy
eloitanguy 0489392
test and doc fixes
eloitanguy 2adcab3
doc dependency + minor comment in ot __init__.py
eloitanguy 945554e
fetch ot main diffsh
eloitanguy 3e2e5b8
PEP8 fixes
eloitanguy 596edd4
test typo fix
eloitanguy 80fa0b9
ssnb da backend test fix
eloitanguy 0a349ce
moved joint ot mappings to the mapping module
eloitanguy 66e484b
merge with pythonot master
eloitanguy 8eb1542
better ssnb example + ssnb initilisation + small joint_ot_mapping tests
eloitanguy f29dcff
better ssnb example + ssnb initilisation + small joint_ot_mapping tests
eloitanguy 7a5e6d7
removed unused dependency in example
eloitanguy 55a9b28
no longer import mapping in __init__ + example thumbnail fix + made q…
eloitanguy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,5 @@ matplotlib | |
autograd | ||
pymanopt | ||
cvxopt | ||
scikit-learn | ||
scikit-learn | ||
cvxpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ API and modules | |
gnn | ||
gromov | ||
lp | ||
mapping | ||
optim | ||
partial | ||
plot | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# -*- coding: utf-8 -*- | ||
# sphinx_gallery_thumbnail_number = 4 | ||
r""" | ||
===================================================== | ||
Smooth and Strongly Convex Nearest Brenier Potentials | ||
===================================================== | ||
|
||
This example is designed to show how to use SSNB [58] in POT. | ||
SSNB computes an l-strongly convex potential :math:`\varphi` with an L-Lipschitz gradient such that | ||
:math:`\nabla \varphi \# \mu \approx \nu`. This regularity can be enforced only on the components of a partition | ||
of the ambient space, which is a relaxation compared to imposing global regularity. | ||
|
||
In this example, we consider a source measure :math:`\mu_s` which is the uniform measure on the unit square in | ||
:math:`\mathbb{R}^2`, and the target measure :math:`\mu_t` which is the image of :math:`\mu_x` by | ||
:math:`T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), 2 * x_2)`. The map :math:`T` is non-smooth, and we wish to approximate | ||
it using a "Brenier-style" map :math:`\nabla \varphi` which is regular on the partition | ||
:math:`\lbrace x_1 <=0, x_1>0\rbrace`, which is well adapted to this particular dataset. | ||
|
||
We represent the gradients of the "bounding potentials" :math:`\varphi_l, \varphi_u` (from [59], Theorem 3.14), | ||
which bound any SSNB potential which is optimal in the sense of [58], Definition 1: | ||
|
||
.. math:: | ||
\varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t), | ||
|
||
where :math:`\mathcal{F}` is the space functions that are on every set :math:`E_k` l-strongly convex | ||
with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambient source space. | ||
|
||
We perform the optimisation on a low amount of fitting samples and with few iterations, | ||
since solving the SSNB problem is quite computationally expensive. | ||
|
||
THIS EXAMPLE REQUIRES CVXPY | ||
|
||
.. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization: | ||
Smooth and strongly convex brenier potentials in optimal transport. In International Conference | ||
on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. | ||
|
||
.. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for | ||
convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, | ||
2017. | ||
""" | ||
|
||
# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr> | ||
# License: MIT License | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import ot | ||
|
||
# %% | ||
# Generating the fitting data | ||
n_fitting_samples = 30 | ||
rng = np.random.RandomState(seed=0) | ||
Xs = rng.uniform(-1, 1, size=(n_fitting_samples, 2)) | ||
Xs_classes = (Xs[:, 0] < 0).astype(int) | ||
Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1) | ||
|
||
plt.scatter(Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c='blue', label='source class 0') | ||
plt.scatter(Xs[Xs_classes == 1, 0], Xs[Xs_classes == 1, 1], c='dodgerblue', label='source class 1') | ||
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') | ||
plt.axis('equal') | ||
plt.title('Splitting sphere dataset') | ||
plt.legend(loc='upper right') | ||
plt.show() | ||
|
||
# %% | ||
# Plotting image of barycentric projection (SSNB initialisation values) | ||
plt.clf() | ||
pi = ot.emd(ot.unif(n_fitting_samples), ot.unif(n_fitting_samples), ot.dist(Xs, Xt)) | ||
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') | ||
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') | ||
bar_img = pi @ Xt | ||
for i in range(n_fitting_samples): | ||
plt.plot([Xs[i, 0], bar_img[i, 0]], [Xs[i, 1], bar_img[i, 1]], color='black', alpha=.5) | ||
plt.title('Images of in-data source samples by the barycentric map') | ||
plt.legend(loc='upper right') | ||
plt.axis('equal') | ||
plt.show() | ||
|
||
# %% | ||
# Fitting the Nearest Brenier Potential | ||
L = 3 # need L > 2 to allow the 2*y term, default is 1.4 | ||
phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric', | ||
gradient_lipschitz_constant=L) | ||
|
||
# %% | ||
# Plotting the images of the source data | ||
plt.clf() | ||
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') | ||
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') | ||
for i in range(n_fitting_samples): | ||
plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color='black', alpha=.5) | ||
plt.title('Images of in-data source samples by the fitted SSNB') | ||
plt.legend(loc='upper right') | ||
plt.axis('equal') | ||
plt.show() | ||
|
||
# %% | ||
# Computing the predictions (images by nabla phi) for random samples of the source distribution | ||
n_predict_samples = 50 | ||
Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2)) | ||
Ys_classes = (Ys[:, 0] < 0).astype(int) | ||
phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes, | ||
gradient_lipschitz_constant=L) | ||
|
||
# %% | ||
# Plot predictions for the gradient of the lower-bounding potential | ||
plt.clf() | ||
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') | ||
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') | ||
for i in range(n_predict_samples): | ||
plt.plot([Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color='black', alpha=.5) | ||
plt.title('Images of new source samples by $\\nabla \\varphi_l$') | ||
plt.legend(loc='upper right') | ||
plt.axis('equal') | ||
plt.show() | ||
|
||
# %% | ||
# Plot predictions for the gradient of the upper-bounding potential | ||
plt.clf() | ||
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') | ||
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') | ||
for i in range(n_predict_samples): | ||
plt.plot([Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color='black', alpha=.5) | ||
plt.title('Images of new source samples by $\\nabla \\varphi_u$') | ||
plt.legend(loc='upper right') | ||
plt.axis('equal') | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably not necessary to load them here, those are pretty specific functions (we usually only load generic functions or some that were there before)), same comment below on the lists in text
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review @rflamary!, The last commit addresses your comment, and fixes the example icon. I also made the function
ot.mapping._ssnb_qcqp_constants
private (added the_
), since a user would have no use for this helper function on its own.