Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Nearest Brenier Potentials #526

Merged
merged 25 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f2c65a8
added functions to a new mapping module
eloitanguy Sep 20, 2023
284c004
simplify ssnb function structure
eloitanguy Sep 20, 2023
a2d9ae8
RELEASES.md conflict fix
eloitanguy Sep 20, 2023
cd7be18
SSNB example
eloitanguy Sep 20, 2023
c82e0e6
removed numpy saves from example for prod
eloitanguy Sep 20, 2023
79a99ef
tests apart from the import exception catch
eloitanguy Sep 20, 2023
801c280
tests apart from the import exception catch
eloitanguy Sep 20, 2023
ff46975
da class and tests
eloitanguy Sep 21, 2023
60944f5
guessed PR number
eloitanguy Sep 21, 2023
d0d42be
Merge remote-tracking branch 'origin/master' into contrib_ssnb
eloitanguy Sep 21, 2023
7bc3213
removed unused import
eloitanguy Sep 21, 2023
55f0e09
PEP8 tab errors fix
eloitanguy Sep 21, 2023
9dfd82e
skip ssnb test if no cvxpy
eloitanguy Sep 21, 2023
0489392
test and doc fixes
eloitanguy Sep 21, 2023
2adcab3
doc dependency + minor comment in ot __init__.py
eloitanguy Sep 21, 2023
945554e
fetch ot main diffsh
eloitanguy Sep 21, 2023
3e2e5b8
PEP8 fixes
eloitanguy Sep 21, 2023
596edd4
test typo fix
eloitanguy Sep 21, 2023
80fa0b9
ssnb da backend test fix
eloitanguy Sep 21, 2023
0a349ce
moved joint ot mappings to the mapping module
eloitanguy Sep 21, 2023
66e484b
merge with pythonot master
eloitanguy Sep 26, 2023
8eb1542
better ssnb example + ssnb initilisation + small joint_ot_mapping tests
eloitanguy Sep 27, 2023
f29dcff
better ssnb example + ssnb initilisation + small joint_ot_mapping tests
eloitanguy Sep 27, 2023
7a5e6d7
removed unused dependency in example
eloitanguy Sep 27, 2023
55a9b28
no longer import mapping in __init__ + example thumbnail fix + made q…
eloitanguy Oct 11, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ POT provides the following generic OT solvers (links to examples):
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
* Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59].

POT provides the following Machine Learning related solvers:

Expand Down Expand Up @@ -329,3 +330,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4),
1178-1198.

[58] Paty F-P., d’Aspremont 1., & Cuturi M. (2020). [Regularity as regularization:Smooth and strongly convex brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.

[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.

3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
## 0.9.2dev

#### New features
+ Tweaked `get_backend` to ignore `None` inputs (PR # 525)
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)

#### Closed issues
Expand Down
3 changes: 2 additions & 1 deletion docs/requirements_rtd.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ matplotlib
autograd
pymanopt
cvxopt
scikit-learn
scikit-learn
cvxpy
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ API and modules
gnn
gromov
lp
mapping
optim
partial
plot
Expand Down
127 changes: 127 additions & 0 deletions examples/others/plot_SSNB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
# sphinx_gallery_thumbnail_number = 4
r"""
=====================================================
Smooth and Strongly Convex Nearest Brenier Potentials
=====================================================

This example is designed to show how to use SSNB [58] in POT.
SSNB computes an l-strongly convex potential :math:`\varphi` with an L-Lipschitz gradient such that
:math:`\nabla \varphi \# \mu \approx \nu`. This regularity can be enforced only on the components of a partition
of the ambient space, which is a relaxation compared to imposing global regularity.

In this example, we consider a source measure :math:`\mu_s` which is the uniform measure on the unit square in
:math:`\mathbb{R}^2`, and the target measure :math:`\mu_t` which is the image of :math:`\mu_x` by
:math:`T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), 2 * x_2)`. The map :math:`T` is non-smooth, and we wish to approximate
it using a "Brenier-style" map :math:`\nabla \varphi` which is regular on the partition
:math:`\lbrace x_1 <=0, x_1>0\rbrace`, which is well adapted to this particular dataset.

We represent the gradients of the "bounding potentials" :math:`\varphi_l, \varphi_u` (from [59], Theorem 3.14),
which bound any SSNB potential which is optimal in the sense of [58], Definition 1:

.. math::
\varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t),

where :math:`\mathcal{F}` is the space functions that are on every set :math:`E_k` l-strongly convex
with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambient source space.

We perform the optimisation on a low amount of fitting samples and with few iterations,
since solving the SSNB problem is quite computationally expensive.

THIS EXAMPLE REQUIRES CVXPY

.. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
Smooth and strongly convex brenier potentials in optimal transport. In International Conference
on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.

.. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for
convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium,
2017.
"""

# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
# License: MIT License

import matplotlib.pyplot as plt
import numpy as np
import ot

# %%
# Generating the fitting data
n_fitting_samples = 30
rng = np.random.RandomState(seed=0)
Xs = rng.uniform(-1, 1, size=(n_fitting_samples, 2))
Xs_classes = (Xs[:, 0] < 0).astype(int)
Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1)

plt.scatter(Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c='blue', label='source class 0')
plt.scatter(Xs[Xs_classes == 1, 0], Xs[Xs_classes == 1, 1], c='dodgerblue', label='source class 1')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
plt.axis('equal')
plt.title('Splitting sphere dataset')
plt.legend(loc='upper right')
plt.show()

# %%
# Plotting image of barycentric projection (SSNB initialisation values)
plt.clf()
pi = ot.emd(ot.unif(n_fitting_samples), ot.unif(n_fitting_samples), ot.dist(Xs, Xt))
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
bar_img = pi @ Xt
for i in range(n_fitting_samples):
plt.plot([Xs[i, 0], bar_img[i, 0]], [Xs[i, 1], bar_img[i, 1]], color='black', alpha=.5)
plt.title('Images of in-data source samples by the barycentric map')
plt.legend(loc='upper right')
plt.axis('equal')
plt.show()

# %%
# Fitting the Nearest Brenier Potential
L = 3 # need L > 2 to allow the 2*y term, default is 1.4
phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric',
gradient_lipschitz_constant=L)

# %%
# Plotting the images of the source data
plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
for i in range(n_fitting_samples):
plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color='black', alpha=.5)
plt.title('Images of in-data source samples by the fitted SSNB')
plt.legend(loc='upper right')
plt.axis('equal')
plt.show()

# %%
# Computing the predictions (images by nabla phi) for random samples of the source distribution
n_predict_samples = 50
Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2))
Ys_classes = (Ys[:, 0] < 0).astype(int)
phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes,
gradient_lipschitz_constant=L)

# %%
# Plot predictions for the gradient of the lower-bounding potential
plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
for i in range(n_predict_samples):
plt.plot([Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color='black', alpha=.5)
plt.title('Images of new source samples by $\\nabla \\varphi_l$')
plt.legend(loc='upper right')
plt.axis('equal')
plt.show()

# %%
# Plot predictions for the gradient of the upper-bounding potential
plt.clf()
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
for i in range(n_predict_samples):
plt.plot([Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color='black', alpha=.5)
plt.title('Images of new source samples by $\\nabla \\varphi_u$')
plt.legend(loc='upper right')
plt.axis('equal')
plt.show()
26 changes: 15 additions & 11 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
, :py:mod:`ot.unbalanced`.
, :py:mod:`ot.unbalanced`, :py:mod`ot.mapping`.
The following sub-modules are not imported due to additional dependencies:
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- :any:`ot.plot` : depends on :code:`matplotlib`
Expand Down Expand Up @@ -35,37 +35,41 @@
from . import factored
from . import solvers
from . import gaussian
from . import mapping

# OT functions
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
binary_search_circle, wasserstein_circle,
semidiscrete_wasserstein2_unif_circle)
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
binary_search_circle, wasserstein_circle,
semidiscrete_wasserstein2_unif_circle)
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve
from .mapping import (nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds, joint_OT_mapping_kernel,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably not necessary to load them here, those are pretty specific functions (we usually only load generic functions or some that were there before)), same comment below on the lists in text

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review @rflamary!, The last commit addresses your comment, and fixes the example icon. I also made the function ot.mapping._ssnb_qcqp_constants private (added the _), since a user would have no use for this helper function on its own.

joint_OT_mapping_linear)

# utils functions
from .utils import dist, unif, tic, toc, toq

__version__ = "0.9.1"
__version__ = "0.9.2dev"

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'nearest_brenier_potential_fit',
'nearest_brenier_potential_predict_bounds', 'joint_OT_mapping_kernel', 'joint_OT_mapping_linear']
Loading