From f2c65a88238e8d1354aa5702fbeb9d05dceb703b Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 20 Sep 2023 13:55:59 +0200
Subject: [PATCH 01/21] added functions to a new mapping module

---
 README.md        |   5 +
 ot/__init__.py   |   1 +
 ot/mapping.py    | 384 +++++++++++++++++++++++++++++++++++++++++++++++
 ot/sliced.py     |   2 +-
 requirements.txt |   3 +-
 5 files changed, 393 insertions(+), 2 deletions(-)
 create mode 100644 ot/mapping.py

diff --git a/README.md b/README.md
index d8ccef82f..f622c5aab 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,7 @@ POT provides the following generic OT solvers (links to examples):
 * [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
 * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
 * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
+* Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59].

 POT provides the following Machine Learning related solvers:

@@ -329,3 +330,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 distances between Gaussian distributions](https://hal.science/hal-03197398v2/file/main.pdf). Journal of Applied Probability, 59(4), 1178-1198.

+[58] Paty F-P., d’Aspremont A., & Cuturi M. (2020). [Regularity as regularization: Smooth and strongly convex Brenier potentials in optimal transport.](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR.
+
+[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium.
+
diff --git a/ot/__init__.py b/ot/__init__.py
index a0846428f..8e8c6db6a 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -35,6 +35,7 @@
 from . import factored
 from . import solvers
 from . import gaussian
+from . import mapping

 # OT functions
 from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
diff --git a/ot/mapping.py b/ot/mapping.py
new file mode 100644
index 000000000..05ae5f1bc
--- /dev/null
+++ b/ot/mapping.py
@@ -0,0 +1,384 @@
+# -*- coding: utf-8 -*-
+"""
+Optimal Transport maps and variants
+"""
+
+# Author: Eloi Tanguy
+#
+# License: MIT License
+
+from .backend import get_backend
+from .lp import emd
+import numpy as np
+from .utils import dist, unif
+
+
+def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6,
+                              gradient_lipschitz_constant=1.4, its=100, log=False, seed=None):
+    r"""
+    Computes optimal values and gradients at X for a strongly convex potential :math:`\\varphi` with Lipschitz gradients
+    on the partitions defined by `X_classes`, where :math:`\\varphi` is optimal such that
+    :math:`\\nabla \\varphi \#\\mu \\approx \\nu`, given samples :math:`X = x_1, \\cdots, x_n \\sim \\mu` and
+    :math:`V = v_1, \\cdots, v_n \\sim \\nu`.
+    Finding such a potential that has the desired regularity on the
+    partition :math:`(E_k)_{k \in [K]}` (given by the classes `X_classes`) is equivalent to finding optimal values
+    `phi` for the :math:`\\varphi(x_i)` and its gradients :math:`\\nabla \\varphi(x_i)` (variable `G`).
+    In practice, these optimal values are found by solving the following problem
+
+    .. math::
+        \\text{min} \\sum_{i,j}\\pi_{i,j}\|g_i - v_j\|_2^2
+
+        g_1,\\cdots, g_n \in \mathbb{R}^d,\; \\varphi_1, \\cdots, \\varphi_n \in \mathbb{R},\; \pi \in \Pi(a, b)
+
+        \\text{s.t.}\ \\forall k \in [K],\; \\forall i,j \in I_k:
+
+        \\varphi_i-\\varphi_j-\langle g_j, x_i-x_j\\rangle \geq c_1\|g_i - g_j\|_2^2
+        + c_2\|x_i-x_j\|_2^2 - c_3\langle g_j-g_i, x_j -x_i \\rangle.
+
+    The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`.
+    The constraint :math:`\pi \in \Pi(a, b)` denotes the fact that the matrix :math:`\pi` belongs to the OT polytope
+    of marginals a and b. :math:`I_k` is the subset of :math:`[n]` of the indices i such that :math:`x_i` is in the
+    partition (or class) :math:`E_k`, i.e. `X_classes[i] == k`.
+
+    This problem is solved by alternating over the variable :math:`\pi` and the variables :math:`\\varphi_i, g_i`.
+    For :math:`\pi`, the problem is the standard discrete OT problem, and for :math:`\\varphi_i, g_i`, the
+    problem is a convex QCQP solved using :code:`cvxpy` (ECOS solver).
+
+    Parameters
+    ----------
+    X: array-like (n, d)
+        reference points used to compute the optimal values phi and G
+    V: array-like (n, d)
+        values of the gradients at the reference points X
+    X_classes : array-like (n,), optional
+        classes of the reference points, defaults to a single class
+    a: array-like (n,), optional
+        weights for the reference points X, defaults to uniform
+    b: array-like (n,), optional
+        weights for the target points V, defaults to uniform
+    strongly_convex_constant : float, optional
+        constant for the strong convexity of the input potential phi, defaults to 0.6
+    gradient_lipschitz_constant : float, optional
+        constant for the Lipschitz property of the input gradient G, defaults to 1.4
+    its: int, optional
+        number of iterations, defaults to 100
+    pbar: bool, optional
+        if True show a progress bar, defaults to False
+    log : bool, optional
+        record log if true
+    seed: int or RandomState or None, optional
+        Seed used for random number generator
+
+    Returns
+    -------
+    phi : array-like (n,)
+        optimal values of the potential at the points X
+    G : array-like (n, d)
+        optimal values of the gradients at the points X
+    log : dict, optional
+        If input log is true, a dictionary containing the values of the variables at each iteration, as well
+        as solver information
+
+    References
+    ----------
+
+    .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
+            Smooth and strongly convex brenier potentials in optimal transport. In International Conference
+            on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
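+
+    Examples
+    --------
+    A minimal sketch of the intended usage, on a toy problem where the target gradients are a
+    translation of the source points (this is a hedged illustration: it requires :code:`cvxpy`,
+    and the fitted values depend on the ECOS solver):
+
+    >>> import numpy as np
+    >>> from ot.mapping import nearest_brenier_potential
+    >>> X = np.random.randn(5, 2)
+    >>> V = X + 2.  # gradient field of the translation potential x -> .5 * ||x||^2 + <2, x>
+    >>> phi, G = nearest_brenier_potential(X, V, its=3)  # doctest: +SKIP
+    >>> G.shape  # doctest: +SKIP
+    (5, 2)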
+ + """ + assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" + if X_classes is not None and a is None and b is None: + nx = get_backend(X, V, X_classes) + if X_classes is None and a is not None and b is None: + nx = get_backend(X, V, a) + else: + nx = get_backend(X, V) + assert 0 <= strongly_convex_constant <= gradient_lipschitz_constant, "incompatible regularity assumption" + n, d = X.shape + if X_classes is not None: + assert X_classes.size == n, "incorrect number of class items" + else: + X_classes = nx.zeros(n) + if a is None: + a = ot.unif(n) + if b is None: + b = ot.unif(n) + assert a.size == b.size == n, 'incorrect measure weight sizes' + + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + G = np.random.randn(n, d) + else: + if seed is not None: + nx.seed(seed) + G = nx.randn(n, d) + + phi = None + log_dict = { + 'G_list': [], + 'phi_list': [], + 'its': [] + } + + for _ in range(its): # alternate optimisation iterations + cost_matrix = dist(G, V) + # optimise the plan + plan = emd(a, b, cost_matrix) + # optimise the values phi and the gradients G + out = solve_nearest_brenier_potential_qcqp(plan, X, X_classes, V, + strongly_convex_constant, gradient_lipschitz_constant, log) + if not log: + phi, G = out + else: + phi, G, it_log_dict = out + log_dict['its'].append(it_log_dict) + log_dict['G_list'].append(G) + log_dict['phi_list'].append(phi) + + if not log: + return phi, G + return phi, G, log_dict + + +def qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): + r""" + Handy function computing the constants for the Nearest Brenier Potential QCQP problems + + Parameters + ---------- + strongly_convex_constant : float + gradient_lipschitz_constant : float + + Returns + ------- + c1 : float + c2 : float + c3 : float + + """ + c = 1 / (2 * (1 - strongly_convex_constant / gradient_lipschitz_constant)) + c1 = c / gradient_lipschitz_constant + c2 = strongly_convex_constant * c + c3 = 2 * strongly_convex_constant * c / gradient_lipschitz_constant + return c1, c2, c3 + + +def solve_nearest_brenier_potential_qcqp(plan, X, X_classes, V, strongly_convex_constant=0.6, + gradient_lipschitz_constant=1.4, log=False): + r""" + Solves the QCQP problem from `nearest_brenier_potential`, using the method from :ref:`[58]`. + + Parameters + ---------- + plan : array-like (n, n) + fixed OT plan matrix + X: array-like (n, d) + reference points used to compute the optimal values phi and G + X_classes : array-like (n,) + classes of the reference points + V: array-like (n, d) + values of the gradients at the reference points X + strongly_convex_constant : float, optional + constant for the strong convexity of the input potential phi, defaults to 0.6 + gradient_lipschitz_constant : float, optional + constant for the Lipschitz property of the input gradient G, defaults to 1.4 + log : bool, optional + record log if true + + Returns + ------- + phi : array-like (n,) + optimal values of the potential at the points X + G : array-like (n, d) + optimal values of the gradients at the points X + log : dict, optional + If input log is true, a dictionary containing solver information + + References + ---------- + + .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization: + Smooth and strongly convex brenier potentials in optimal transport. In International Conference + on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. 
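+
+    Examples
+    --------
+    A hedged sketch of a direct call with a fixed diagonal plan, a single class and the default
+    regularity constants (requires :code:`cvxpy`; the returned values depend on the ECOS solver):
+
+    >>> import numpy as np
+    >>> from ot.mapping import solve_nearest_brenier_potential_qcqp
+    >>> n, d = 4, 2
+    >>> X = np.random.randn(n, d)
+    >>> plan = np.eye(n) / n  # diagonal OT plan pairing each x_i with v_i
+    >>> phi, G = solve_nearest_brenier_potential_qcqp(plan, X, np.zeros(n), X + 1.)  # doctest: +SKIP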
+ + """ + try: + import cvxpy as cvx + except ImportError: + print('Please install CVXPY to use this function') + assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" + assert 0 <= strongly_convex_constant <= gradient_lipschitz_constant, "incompatible regularity assumption" + n, d = X.shape + assert X_classes.size == n, "incorrect number of class items" + assert plan.shape == (n, n), f'plan should be of shape {(n, n)} but is of shape {plan.shape}' + phi = cvx.Variable(n) + G = cvx.Variable((n, d)) + constraints = [] + cost = 0 + for i in range(n): + for j in range(n): + cost += cvx.sum_squares(G[i, :] - V[j, :]) * plan[i, j] + objective = cvx.Minimize(cost) # OT cost + c1, c2, c3 = qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + + for k in np.unique(X_classes): # constraints for the convex interpolation + for i in np.where(X_classes == k)[0]: + for j in np.where(X_classes == k)[0]: + constraints += [ + phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) \ + + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) + ] + problem = cvx.Problem(objective, constraints) + problem.solve(solver=cvx.ECOS) + + if not log: + return phi.value, G.value + log_dict = { + 'solve_time': problem.solver_stats.solve_time, + 'setup_time': problem.solver_stats.setup_time, + 'num_iters': problem.solver_stats.num_iters, + 'status': problem.status, + 'value': problem.value + } + return phi.value, G.value, log_dict + + +def bounding_potentials_from_point_values(X, X_classes, phi, G, Y, Y_classes, strongly_convex_constant=0.6, + gradient_lipschitz_constant=1.4, log=False): + r""" + Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal + values phi at X and their gradients G at X. The 'lower' potential corresponds to the method from :ref:`[58]`, + Equation 2, while the bounding property and 'upper' potential come from :ref:`[59]`, Theorem 3.14 (taking into + account the fact that this theorem's statement has a min instead of a max, which is a typo). + + If :math:`I_k` is the subset of :math:`[n]` of the i such that :math:`x_i` is in the partition (or class) + :math:`E_k`, for each :math:`y \in E_k`, this function solves the convex QCQP problems, + respectively for l: 'lower' and u: 'upper: + + .. math:: + (\\varphi_{l}(x), \\nabla \\varphi_l(x)) = \\text{argmin}\ t, + + t\in \mathbb{R},\; g\in \mathbb{R}^d, + + \\text{s.t.} \\forall j \in I_k,\; t-\\varphi_j - \langle g_j, y-x_j \\rangle \geq c_1\|g - g_j\|_2^2 + + c_2\|y-x_j\|_2^2 - c_3\langle g_j-g, x_j -y \\rangle. + + .. math:: + (\\varphi_{u}(x), \\nabla \\varphi_u(x)) = \\text{argmax}\ t, + + t\in \mathbb{R},\; g\in \mathbb{R}^d, + + \\text{s.t.} \\forall i \in I_k,\; \\varphi_i^* -t - \langle g, x_i-y \\rangle \geq c_1\|g_i - g\|_2^2 + + c_2\|x_i-y\|_2^2 - c_3\langle g-g_i, y -x_i \\rangle. + + The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`. 
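+
+    As a worked sanity check on these constants (following the formulas implemented in
+    :code:`qcqp_constants`): with the default values :math:`l = 0.6` and :math:`L = 1.4`,
+    writing :math:`c = 1 / (2 (1 - l/L)) = 7/8`, one gets :math:`c_1 = c / L = 0.625`,
+    :math:`c_2 = l c = 0.525` and :math:`c_3 = 2 l c / L = 0.75`.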
+ + Parameters + ---------- + X : array-like (n, d) + reference points used to compute the optimal values phi and G + X_classes : array-like (n,) + classes of the reference points + phi : array-like (n,) + optimal values of the potential at the points X + G : array-like (n, d) + optimal values of the gradients at the points X + Y : array-like (m, d) + input points + Y_classes : array_like (m) + classes of the input points + strongly_convex_constant : float, optional + constant for the strong convexity of the input potential phi, defaults to 0.6 + gradient_lipschitz_constant : float, optional + constant for the Lipschitz property of the input gradient G, defaults to 1.4 + log : bool, optional + record log if true + + Returns + ------- + phi_lu: array-like (2, m) + values of the lower and upper bounding potentials at Y + G_lu: array-like (2, m, d) + gradients of the lower and upper bounding potentials at Y + log : dict, optional + If input log is true, a dictionary containing solver information + + References + ---------- + + .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization: + Smooth and strongly convex brenier potentials in optimal transport. In International Conference + on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. + + .. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for + convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, + 2017. + + """ + try: + import cvxpy as cvx + except ImportError: + print('Please install CVXPY to use this function') + m, d = Y.shape + assert Y_classes.size == m, 'wrong number of class items for Y' + assert X.shape[1] == d, f'incompatible dimensions between X: {X.shape} and Y: {Y.shape}' + n, _ = X.shape + assert X_classes.size == n, 'wrong number of class items for X' + c1, c2, c3 = qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + phi_lu = np.zeros((2, m)) + G_lu = np.zeros((2, m, d)) + log_dict = {} + + for y_idx in range(m): + log_item = {} + # lower bound + phi_l_y = cvx.Variable(1) + G_l_y = cvx.Variable(d) + objective = cvx.Minimize(phi_l_y) + constraints = [] + k = Y_classes[y_idx] + for j in np.where(X_classes == k)[0]: + constraints += [ + phi_l_y >= phi[j] + G[j].T @ (Y[y_idx] - X[j]) + c1 * cvx.sum_squares(G_l_y - G[j]) \ + + c2 * cvx.sum_squares(Y[y_idx] - X[j]) - c3 * (G[j] - G_l_y).T @ (X[j] - Y[y_idx]) + ] + problem = cvx.Problem(objective, constraints) + problem.solve(solver=cvx.ECOS) + phi_lu[0, y_idx] = phi_l_y.value + G_lu[0, y_idx] = G_l_y.value + if log: + log_item['l'] = { + 'solve_time': problem.solver_stats.solve_time, + 'setup_time': problem.solver_stats.setup_time, + 'num_iters': problem.solver_stats.num_iters, + 'status': problem.status, + 'value': problem.value + } + + # upper bound + phi_u_y = cvx.Variable(1) + G_u_y = cvx.Variable(d) + objective = cvx.Maximize(phi_u_y) + constraints = [] + for i in np.where(X_classes == k)[0]: + constraints += [ + phi[i] >= phi_u_y + G_u_y.T @ (X[i] - Y[y_idx]) + c1 * cvx.sum_squares(G[i] - G_u_y) \ + + c2 * cvx.sum_squares(X[i] - Y[y_idx]) - c3 * (G_u_y - G[i]).T @ (Y[y_idx] - X[i]) + ] + problem = cvx.Problem(objective, constraints) + problem.solve(solver=cvx.ECOS) + phi_lu[1, y_idx] = phi_u_y.value + G_lu[1, y_idx] = G_u_y.value + if log: + log_item['u'] = { + 'solve_time': problem.solver_stats.solve_time, + 'setup_time': problem.solver_stats.setup_time, + 'num_iters': problem.solver_stats.num_iters, + 'status': 
problem.status, + 'value': problem.value + } + log_dict[y_idx] = log_item + + if not log: + return phi_lu, G_lu + return phi_lu, G_lu, log_dict + diff --git a/ot/sliced.py b/ot/sliced.py index fd86df97f..d5bb0ee08 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -29,7 +29,7 @@ def get_random_projections(d, n_projections, seed=None, backend=None, type_as=No seed: int or RandomState, optional Seed used for numpy random number generator backend: - Backend to ue for random generation + Backend to use for random generation Returns ------- diff --git a/requirements.txt b/requirements.txt index f96e89285..fd39cbab4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ jax jaxlib tensorflow pytest -torch_geometric \ No newline at end of file +torch_geometric +cvxpy \ No newline at end of file From 284c0046bde1880884c9d64d870f51b0e86214b4 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 20 Sep 2023 15:45:11 +0200 Subject: [PATCH 02/21] simplify ssnb function structure --- RELEASES.md | 5 ++ ot/backend.py | 8 +- ot/mapping.py | 171 ++++++++++++++++--------------------------- test/test_backend.py | 8 ++ 4 files changed, 81 insertions(+), 111 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index d0209e233..b522a8ed9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,10 @@ # Releases +## 0.9.2dev + +- Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #) +- Tweaked `get_backend` to ignore `None` inputs (PR #525) + ## 0.9.1 *August 2023* diff --git a/ot/backend.py b/ot/backend.py index 7b2fe875f..e0eeaf05f 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -150,18 +150,22 @@ def _check_args_backend(backend, args): if len(is_instance) == 1: return is_instance.pop() - # Oterwise return an error + # Otherwise return an error raise ValueError(str_type_error.format([type(a) for a in args])) def get_backend(*args): """Returns the proper backend for a list of input arrays + Accepts None entries in the arguments, and ignores them + Also raises TypeError if all arrays are not from the same backend """ + args = [arg for arg in args if arg is not None] # exclude None entries + # check that some arrays given if not len(args) > 0: - raise ValueError(" The function takes at least one parameter") + raise ValueError(" The function takes at least one (non-None) parameter") for backend in _BACKENDS: if _check_args_backend(backend, args): diff --git a/ot/mapping.py b/ot/mapping.py index 05ae5f1bc..2db2c50c6 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -86,13 +86,13 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. 
""" + try: + import cvxpy as cvx + except ImportError: + print('Please install CVXPY to use this function') + return assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" - if X_classes is not None and a is None and b is None: - nx = get_backend(X, V, X_classes) - if X_classes is None and a is not None and b is None: - nx = get_backend(X, V, a) - else: - nx = get_backend(X, V) + nx = get_backend(X, V, X_classes, a, b) assert 0 <= strongly_convex_constant <= gradient_lipschitz_constant, "incompatible regularity assumption" n, d = X.shape if X_classes is not None: @@ -100,19 +100,19 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con else: X_classes = nx.zeros(n) if a is None: - a = ot.unif(n) + a = unif(n, type_as=X) if b is None: - b = ot.unif(n) + b = unif(n, type_as=X) assert a.size == b.size == n, 'incorrect measure weight sizes' - if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': - G = np.random.randn(n, d) + if isinstance(seed, np.random.RandomState): + G_val = np.random.randn(n, d) else: if seed is not None: - nx.seed(seed) - G = nx.randn(n, d) + np.random.seed(seed) + G_val = np.random.randn(n, d) - phi = None + phi_val = None log_dict = { 'G_list': [], 'phi_list': [], @@ -124,22 +124,47 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con # optimise the plan plan = emd(a, b, cost_matrix) # optimise the values phi and the gradients G - out = solve_nearest_brenier_potential_qcqp(plan, X, X_classes, V, - strongly_convex_constant, gradient_lipschitz_constant, log) - if not log: - phi, G = out - else: - phi, G, it_log_dict = out + phi = cvx.Variable(n) + G = cvx.Variable((n, d)) + constraints = [] + cost = 0 + for i in range(n): + for j in range(n): + cost += cvx.sum_squares(G[i, :] - V[j, :]) * plan[i, j] + objective = cvx.Minimize(cost) # OT cost + c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + + for k in np.unique(X_classes): # constraints for the convex interpolation + for i in np.where(X_classes == k)[0]: + for j in np.where(X_classes == k)[0]: + constraints += [ + phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) \ + + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) + ] + problem = cvx.Problem(objective, constraints) + problem.solve(solver=cvx.ECOS) + it_log_dict = { + 'solve_time': problem.solver_stats.solve_time, + 'setup_time': problem.solver_stats.setup_time, + 'num_iters': problem.solver_stats.num_iters, + 'status': problem.status, + 'value': problem.value + } + phi_val, G_val = phi.value, G.value + if log: log_dict['its'].append(it_log_dict) - log_dict['G_list'].append(G) - log_dict['phi_list'].append(phi) + log_dict['G_list'].append(G_val) + log_dict['phi_list'].append(phi_val) + # convert back to backend + phi_val = nx.from_numpy(phi_val) + G_val = nx.from_numpy(G_val) if not log: - return phi, G - return phi, G, log_dict + return phi_val, G_val + return phi_val, G_val, log_dict -def qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): +def ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): r""" Handy function computing the constants for the Nearest Brenier Potential QCQP problems @@ -162,88 +187,8 @@ def qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): return c1, c2, c3 -def solve_nearest_brenier_potential_qcqp(plan, X, X_classes, V, strongly_convex_constant=0.6, - 
gradient_lipschitz_constant=1.4, log=False): - r""" - Solves the QCQP problem from `nearest_brenier_potential`, using the method from :ref:`[58]`. - - Parameters - ---------- - plan : array-like (n, n) - fixed OT plan matrix - X: array-like (n, d) - reference points used to compute the optimal values phi and G - X_classes : array-like (n,) - classes of the reference points - V: array-like (n, d) - values of the gradients at the reference points X - strongly_convex_constant : float, optional - constant for the strong convexity of the input potential phi, defaults to 0.6 - gradient_lipschitz_constant : float, optional - constant for the Lipschitz property of the input gradient G, defaults to 1.4 - log : bool, optional - record log if true - - Returns - ------- - phi : array-like (n,) - optimal values of the potential at the points X - G : array-like (n, d) - optimal values of the gradients at the points X - log : dict, optional - If input log is true, a dictionary containing solver information - - References - ---------- - - .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization: - Smooth and strongly convex brenier potentials in optimal transport. In International Conference - on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. - - """ - try: - import cvxpy as cvx - except ImportError: - print('Please install CVXPY to use this function') - assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" - assert 0 <= strongly_convex_constant <= gradient_lipschitz_constant, "incompatible regularity assumption" - n, d = X.shape - assert X_classes.size == n, "incorrect number of class items" - assert plan.shape == (n, n), f'plan should be of shape {(n, n)} but is of shape {plan.shape}' - phi = cvx.Variable(n) - G = cvx.Variable((n, d)) - constraints = [] - cost = 0 - for i in range(n): - for j in range(n): - cost += cvx.sum_squares(G[i, :] - V[j, :]) * plan[i, j] - objective = cvx.Minimize(cost) # OT cost - c1, c2, c3 = qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) - - for k in np.unique(X_classes): # constraints for the convex interpolation - for i in np.where(X_classes == k)[0]: - for j in np.where(X_classes == k)[0]: - constraints += [ - phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) \ - + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) - ] - problem = cvx.Problem(objective, constraints) - problem.solve(solver=cvx.ECOS) - - if not log: - return phi.value, G.value - log_dict = { - 'solve_time': problem.solver_stats.solve_time, - 'setup_time': problem.solver_stats.setup_time, - 'num_iters': problem.solver_stats.num_iters, - 'status': problem.status, - 'value': problem.value - } - return phi.value, G.value, log_dict - - -def bounding_potentials_from_point_values(X, X_classes, phi, G, Y, Y_classes, strongly_convex_constant=0.6, - gradient_lipschitz_constant=1.4, log=False): +def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classes=None, + strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False): r""" Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal values phi at X and their gradients G at X. 
The 'lower' potential corresponds to the method from :ref:`[58]`,
@@ -284,8 +229,10 @@ def bounding_potentials_from_point_values(X, X_classes, phi, G, Y, Y_classes, st
         optimal values of the gradients at the points X
     Y : array-like (m, d)
         input points
-    Y_classes : array_like (m)
-        classes of the input points
+    X_classes : array-like (n,), optional
+        classes of the reference points, defaults to a single class
+    Y_classes : array_like (m,), optional
+        classes of the input points, defaults to a single class
     strongly_convex_constant : float, optional
         constant for the strong convexity of the input potential phi, defaults to 0.6
     gradient_lipschitz_constant : float, optional
         constant for the Lipschitz property of the input gradient G, defaults to 1.4
@@ -318,12 +265,18 @@ def bounding_potentials_from_point_values(X, X_classes, phi, G, Y, Y_classes, st
         import cvxpy as cvx
     except ImportError:
         print('Please install CVXPY to use this function')
+        return
+    nx = get_backend(X, X)
     m, d = Y.shape
     assert Y_classes.size == m, 'wrong number of class items for Y'
     assert X.shape[1] == d, f'incompatible dimensions between X: {X.shape} and Y: {Y.shape}'
     n, _ = X.shape
+    if X_classes is not None:
+        assert X_classes.size == n, "incorrect number of class items"
+    else:
+        X_classes = nx.zeros(n)
     assert X_classes.size == n, 'wrong number of class items for X'
-    c1, c2, c3 = qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant)
+    c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant)
     phi_lu = np.zeros((2, m))
     G_lu = np.zeros((2, m, d))
     log_dict = {}
diff --git a/test/test_backend.py b/test/test_backend.py
index f0571471c..00e183dd6 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -753,3 +753,11 @@ def fun(a, b, d):
     [dl_dw, dl_db] = tape.gradient(manipulated_loss, [w, b])
     assert nx.allclose(dl_dw, w)
     assert nx.allclose(dl_db, b)
+
+
+def test_get_backend_none():
+    a, b = np.zeros((2, 3)), None
+    nx = get_backend(a, b)
+    assert str(nx) == 'numpy'
+    with pytest.raises(ValueError):
+        get_backend(None, None)

From cd7be18c1bac05f16a25510a5fe5f4fee23384c4 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 20 Sep 2023 17:26:26 +0200
Subject: [PATCH 03/21] SSNB example

---
 examples/others/plot_SSNB.py | 110 +++++++++++++++++++++++++++++++++++
 ot/__init__.py               |  22 +++----
 ot/mapping.py                |  49 +++++++++-------
 3 files changed, 150 insertions(+), 31 deletions(-)
 create mode 100644 examples/others/plot_SSNB.py

diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py
new file mode 100644
index 000000000..d7e9b8963
--- /dev/null
+++ b/examples/others/plot_SSNB.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+r"""
+=====================================================
+Smooth and Strongly Convex Nearest Brenier Potentials
+=====================================================
+
+This example is designed to show how to use SSNB [58] in POT.
+SSNB computes an l-strongly convex potential :math:`\varphi` with an L-Lipschitz gradient such that
+:math:`\nabla \varphi \# \mu \approx \nu`. This regularity can be enforced only on the components of a partition
+of the ambient space, which is a relaxation compared to imposing global regularity.
+
+In this example, we consider a source measure :math:`\mu_s` which is the uniform measure on the unit sphere in
+:math:`\mathbb{R}^2`, and the target measure :math:`\mu_t` which is the image of :math:`\mu_s` by
+:math:`T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), x_2)`.
+The map :math:`T` is non-smooth, and we wish to approximate it
+using a "Brenier-style" map :math:`\nabla \varphi` which is regular on the partition
+:math:`\lbrace x_1 \leq 0, x_1 > 0 \rbrace`, which is well adapted to this particular dataset.
+
+We represent the gradients of the "bounding potentials" :math:`\varphi_l, \varphi_u` (from [59], Theorem 3.14),
+which bound any SSNB potential which is optimal in the sense of [58], Definition 1:
+
+.. math::
+    \varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t),
+
+where :math:`\mathcal{F}` is the space of functions that are on every set :math:`E_k` l-strongly convex
+with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambiant source space.
+
+We perform the optimisation on a small number of fitting samples and with few iterations,
+since solving the SSNB problem is quite computationally expensive.
+
+.. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
+    Smooth and strongly convex brenier potentials in optimal transport. In International Conference
+    on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
+
+.. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for
+    convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium,
+    2017.
+"""
+
+# Author: Eloi Tanguy
+# License: MIT License
+
+import matplotlib.pyplot as plt
+import numpy as np
+import ot
+import os
+
+# %%
+# Generating the fitting data
+n_fitting_samples = 16
+t = np.linspace(0, 2 * np.pi, n_fitting_samples)
+r = 1
+Xs = np.stack([r * np.cos(t), r * np.sin(t)], axis=-1)
+Xs_classes = (Xs[:, 0] < 0).astype(int)
+Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), Xs[:, 1]], axis=-1)
+
+plt.scatter(Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c='blue', label='source class 0')
+plt.scatter(Xs[Xs_classes == 1, 0], Xs[Xs_classes == 1, 1], c='dodgerblue', label='source class 1')
+plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
+plt.title('Splitting sphere dataset')
+plt.legend(loc='upper right')
+plt.show()
+
+# %%
+# Fitting the Nearest Brenier Potential
+if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi.npy'):
+    phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, seed=0)
+    np.save('/home/eloi/POT_ssnb/examples/others/phi.npy', phi)
+    np.save('/home/eloi/POT_ssnb/examples/others/G.npy', G)
+else:
+    phi = np.load('/home/eloi/POT_ssnb/examples/others/phi.npy')
+    G = np.load('/home/eloi/POT_ssnb/examples/others/G.npy')
+
+# %%
+# Computing the predictions (images by nabla phi) for random samples of the source distribution
+rng = np.random.RandomState(seed=0)
+n_predict_samples = 100
+t = rng.uniform(0, 2 * np.pi, size=n_predict_samples)
+r = rng.uniform(size=n_predict_samples)
+Ys = np.stack([r * np.cos(t), r * np.sin(t)], axis=-1)
+Ys_classes = (Ys[:, 0] < 0).astype(int)
+
+if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi_lu.npy'):
+    phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes)
+    np.save('/home/eloi/POT_ssnb/examples/others/phi_lu.npy', phi_lu)
+    np.save('/home/eloi/POT_ssnb/examples/others/G_lu.npy', G_lu)
+else:
+    phi_lu = np.load('/home/eloi/POT_ssnb/examples/others/phi_lu.npy')
+    G_lu = np.load('/home/eloi/POT_ssnb/examples/others/G_lu.npy')
+
+# %%
+# Plot predictions for the gradient of the lower-bounding potential
+plt.clf()
+plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue',
label='source') +plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +for i in range(n_predict_samples): + plt.plot([Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color='black', alpha=.5) +plt.title('Images of new source samples by $\\nabla \\varphi_l$') +plt.legend(loc='upper right') +plt.show() + +# %% +# Plot predictions for the gradient of the upper-bounding potential +plt.clf() +plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') +plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +for i in range(n_predict_samples): + plt.plot([Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color='black', alpha=.5) +plt.title('Images of new source samples by $\\nabla \\varphi_u$') +plt.legend(loc='upper right') +plt.show() diff --git a/ot/__init__.py b/ot/__init__.py index 8e8c6db6a..a3adfccb0 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,25 +38,26 @@ from . import mapping # OT functions -from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein_circle, - semidiscrete_wasserstein2_unif_circle) +from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, +from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, - gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve +from .mapping import nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.1" +__version__ = "0.9.2dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', @@ -64,9 +65,10 @@ 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', - 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', + 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', + 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', + 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'nearest_brenier_potential_fit', + 'nearest_brenier_potential_predict_bounds'] diff --git a/ot/mapping.py b/ot/mapping.py index 2db2c50c6..acf502f50 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -7,14 +7,14 @@ # # License: MIT License -from .backend import get_backend +from .backend import get_backend, 
to_numpy from .lp import emd import numpy as np from .utils import dist, unif -def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6, - gradient_lipschitz_constant=1.4, its=100, log=False, seed=None): +def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6, + gradient_lipschitz_constant=1.4, its=100, log=False, seed=None): r""" Computes optimal values and gradients at X for a strongly convex potential :math:`\\varphi` with Lipschitz gradients on the partitions defined by `X_classes`, where :math:`\\varphi` is optimal such that @@ -43,6 +43,8 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con For :math:`\pi`, the problem is the standard discrete OT problem, and for :math:`\\varphi_i, g_i`, the problem is a convex QCQP solved using :code:`cvxpy` (ECOS solver). + Accepts any compatible backend, but will perform the QCQP optimisation on Numpy arrays, and convert back at the end. + Parameters ---------- X: array-like (n, d) @@ -94,6 +96,7 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}" nx = get_backend(X, V, X_classes, a, b) assert 0 <= strongly_convex_constant <= gradient_lipschitz_constant, "incompatible regularity assumption" + X, V = to_numpy(X), to_numpy(V) n, d = X.shape if X_classes is not None: assert X_classes.size == n, "incorrect number of class items" @@ -120,9 +123,10 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con } for _ in range(its): # alternate optimisation iterations - cost_matrix = dist(G, V) + cost_matrix = dist(G_val, V) # optimise the plan plan = emd(a, b, cost_matrix) + # optimise the values phi and the gradients G phi = cvx.Variable(n) G = cvx.Variable((n, d)) @@ -134,15 +138,16 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con objective = cvx.Minimize(cost) # OT cost c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) - for k in np.unique(X_classes): # constraints for the convex interpolation - for i in np.where(X_classes == k)[0]: - for j in np.where(X_classes == k)[0]: + for k in nx.unique(X_classes): # constraints for the convex interpolation + for i in nx.where(X_classes == k)[0]: + for j in nx.where(X_classes == k)[0]: constraints += [ phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) \ + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) + phi_val, G_val = phi.value, G.value it_log_dict = { 'solve_time': problem.solver_stats.solve_time, 'setup_time': problem.solver_stats.setup_time, @@ -150,7 +155,6 @@ def nearest_brenier_potential(X, V, X_classes=None, a=None, b=None, strongly_con 'status': problem.status, 'value': problem.value } - phi_val, G_val = phi.value, G.value if log: log_dict['its'].append(it_log_dict) log_dict['G_list'].append(G_val) @@ -187,8 +191,8 @@ def ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): return c1, c2, c3 -def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classes=None, - strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False): +def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_classes=None, + strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, 
log=False): r""" Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal values phi at X and their gradients G at X. The 'lower' potential corresponds to the method from :ref:`[58]`, @@ -197,7 +201,7 @@ def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classe If :math:`I_k` is the subset of :math:`[n]` of the i such that :math:`x_i` is in the partition (or class) :math:`E_k`, for each :math:`y \in E_k`, this function solves the convex QCQP problems, - respectively for l: 'lower' and u: 'upper: + respectively for l: 'lower' and u: 'upper': .. math:: (\\varphi_{l}(x), \\nabla \\varphi_l(x)) = \\text{argmin}\ t, @@ -266,7 +270,11 @@ def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classe except ImportError: print('Please install CVXPY to use this function') return - nx = get_backend(X, X) + nx = get_backend(X, phi, G, Y) + X = to_numpy(X) + phi = to_numpy(phi) + G = to_numpy(G) + Y = to_numpy(Y) m, d = Y.shape assert Y_classes.size == m, 'wrong number of class items for Y' assert X.shape[1] == d, f'incompatible dimensions between X: {X.shape} and Y: {Y.shape}' @@ -277,8 +285,8 @@ def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classe X_classes = nx.zeros(n) assert X_classes.size == n, 'wrong number of class items for X' c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) - phi_lu = np.zeros((2, m)) - G_lu = np.zeros((2, m, d)) + phi_lu = nx.zeros((2, m)) + G_lu = nx.zeros((2, m, d)) log_dict = {} for y_idx in range(m): @@ -289,15 +297,15 @@ def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classe objective = cvx.Minimize(phi_l_y) constraints = [] k = Y_classes[y_idx] - for j in np.where(X_classes == k)[0]: + for j in nx.where(X_classes == k)[0]: constraints += [ phi_l_y >= phi[j] + G[j].T @ (Y[y_idx] - X[j]) + c1 * cvx.sum_squares(G_l_y - G[j]) \ + c2 * cvx.sum_squares(Y[y_idx] - X[j]) - c3 * (G[j] - G_l_y).T @ (X[j] - Y[y_idx]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) - phi_lu[0, y_idx] = phi_l_y.value - G_lu[0, y_idx] = G_l_y.value + phi_lu[0, y_idx] = nx.from_numpy(phi_l_y.value, type_as=X) + G_lu[0, y_idx] = nx.from_numpy(G_l_y.value, type_as=X) if log: log_item['l'] = { 'solve_time': problem.solver_stats.solve_time, @@ -312,15 +320,15 @@ def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classe G_u_y = cvx.Variable(d) objective = cvx.Maximize(phi_u_y) constraints = [] - for i in np.where(X_classes == k)[0]: + for i in nx.where(X_classes == k)[0]: constraints += [ phi[i] >= phi_u_y + G_u_y.T @ (X[i] - Y[y_idx]) + c1 * cvx.sum_squares(G[i] - G_u_y) \ + c2 * cvx.sum_squares(X[i] - Y[y_idx]) - c3 * (G_u_y - G[i]).T @ (Y[y_idx] - X[i]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) - phi_lu[1, y_idx] = phi_u_y.value - G_lu[1, y_idx] = G_u_y.value + phi_lu[1, y_idx] = nx.from_numpy(phi_u_y.value, type_as=X) + G_lu[1, y_idx] = nx.from_numpy(G_u_y.value, type_as=X) if log: log_item['u'] = { 'solve_time': problem.solver_stats.solve_time, @@ -334,4 +342,3 @@ def bounding_potentials_from_point_values(X, phi, G, Y, X_classes=None, Y_classe if not log: return phi_lu, G_lu return phi_lu, G_lu, log_dict - From c82e0e60137f1ff52f74c4aa9c5d09382766277e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 20 Sep 2023 17:36:52 +0200 Subject: [PATCH 04/21] removed numpy saves from example for prod --- 
examples/others/plot_SSNB.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index d7e9b8963..0fd414d08 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -62,13 +62,7 @@ # %% # Fitting the Nearest Brenier Potential -if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi.npy'): - phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, seed=0) - np.save('/home/eloi/POT_ssnb/examples/others/phi.npy', phi) - np.save('/home/eloi/POT_ssnb/examples/others/G.npy', G) -else: - phi = np.load('/home/eloi/POT_ssnb/examples/others/phi.npy') - G = np.load('/home/eloi/POT_ssnb/examples/others/G.npy') +phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, seed=0) # %% # Computing the predictions (images by nabla phi) for random samples of the source distribution @@ -78,14 +72,7 @@ r = rng.uniform(size=n_predict_samples) Ys = np.stack([r * np.cos(t), r * np.sin(t)], axis=-1) Ys_classes = (Ys[:, 0] < 0).astype(int) - -if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi_lu.npy'): - phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes) - np.save('/home/eloi/POT_ssnb/examples/others/phi_lu.npy', phi_lu) - np.save('/home/eloi/POT_ssnb/examples/others/G_lu.npy', G_lu) -else: - phi_lu = np.load('/home/eloi/POT_ssnb/examples/others/phi_lu.npy') - G_lu = np.load('/home/eloi/POT_ssnb/examples/others/G_lu.npy') +phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes) # %% # Plot predictions for the gradient of the lower-bounding potential From 79a99efb83fa6a3ac4a96d247af24d99da56fb6a Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 20 Sep 2023 18:35:24 +0200 Subject: [PATCH 05/21] tests apart from the import exception catch --- test/test_mapping.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 test/test_mapping.py diff --git a/test/test_mapping.py b/test/test_mapping.py new file mode 100644 index 000000000..2685dcea5 --- /dev/null +++ b/test/test_mapping.py @@ -0,0 +1,35 @@ +"""Tests for module mapping""" +# Author: Eloi Tanguy +# +# License: MIT License + +import numpy as np +import pytest +import ot + + +def test_ssnb_qcqp_constants(): + c1, c2, c3 = ot.mapping.ssnb_qcqp_constants(.5, 1) + np.testing.assert_almost_equal(c1, 1) + np.testing.assert_almost_equal(c2, .5) + np.testing.assert_almost_equal(c3, 1) + + +def test_nearest_brenier_potential_fit(nx): + X = nx.ones((2, 2)) + phi, G, log = ot.nearest_brenier_potential_fit(X, X, its=3, log=True) + np.testing.assert_almost_equal(G, X) # image of source should be close to target + # test without log but with X_classes and seed + ot.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), its=1, seed=0) + # test with seed being a np.random.RandomState + ot.nearest_brenier_potential_fit(X, X, its=1, seed=np.random.RandomState(seed=0)) + + +def test_brenier_potential_predict_bounds(nx): + X = nx.ones((2, 2)) + phi, G = ot.nearest_brenier_potential_fit(X, X, its=3) + phi_lu, G_lu, log = ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, log=True) + np.testing.assert_almost_equal(G_lu[0], X) # 'new' input isn't new, so should be equal to target + np.testing.assert_almost_equal(G_lu[1], X) + # test with no log but classes + ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2)) From 
801c280e98bf3fe6aa3152bcd7e3d51c09308332 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 20 Sep 2023 18:35:34 +0200
Subject: [PATCH 06/21] tests apart from the import exception catch

---
 ot/mapping.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/ot/mapping.py b/ot/mapping.py
index acf502f50..2e9473f29 100644
--- a/ot/mapping.py
+++ b/ot/mapping.py
@@ -95,7 +95,6 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly
         return
     assert X.shape == V.shape, f"point shape should be the same as value shape, yet {X.shape} != {V.shape}"
     nx = get_backend(X, V, X_classes, a, b)
-    assert 0 <= strongly_convex_constant <= gradient_lipschitz_constant, "incompatible regularity assumption"
     X, V = to_numpy(X), to_numpy(V)
     n, d = X.shape
     if X_classes is not None:
@@ -184,6 +183,7 @@ def ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant):
     c3 : float

     """
+    assert 0 < strongly_convex_constant < gradient_lipschitz_constant, "incompatible regularity assumption"
    c = 1 / (2 * (1 - strongly_convex_constant / gradient_lipschitz_constant))
    c1 = c / gradient_lipschitz_constant
    c2 = strongly_convex_constant * c
@@ -276,7 +276,10 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla
     G = to_numpy(G)
     Y = to_numpy(Y)
     m, d = Y.shape
-    assert Y_classes.size == m, 'wrong number of class items for Y'
+    if Y_classes is not None:
+        assert Y_classes.size == m, 'wrong number of class items for Y'
+    else:
+        Y_classes = nx.zeros(m)
     assert X.shape[1] == d, f'incompatible dimensions between X: {X.shape} and Y: {Y.shape}'
     n, _ = X.shape
     if X_classes is not None:

From ff4697552da9a75771dd4975c2b7439242460b45 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Thu, 21 Sep 2023 11:12:28 +0200
Subject: [PATCH 07/21] da class and tests

---
 examples/others/plot_SSNB.py |   5 +-
 ot/da.py                     | 162 ++++++++++++++++
 ot/mapping.py                |  27 ++-
 test/test_da.py              | 349 ++++++++++++++++++++---------------
 4 files changed, 387 insertions(+), 156 deletions(-)

diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py
index 0fd414d08..bc7631554 100644
--- a/examples/others/plot_SSNB.py
+++ b/examples/others/plot_SSNB.py
@@ -22,11 +22,13 @@
     \varphi \in \mathrm{argmin}_{\varphi \in \mathcal{F}}\ \mathrm{W}_2(\nabla \varphi \#\mu_s, \mu_t),

 where :math:`\mathcal{F}` is the space of functions that are on every set :math:`E_k` l-strongly convex
-with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambiant source space.
+with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambient source space.

 We perform the optimisation on a small number of fitting samples and with few iterations,
 since solving the SSNB problem is quite computationally expensive.

+THIS EXAMPLE REQUIRES CVXPY
+
 .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
     Smooth and strongly convex brenier potentials in optimal transport. In International Conference
     on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
@@ -42,7 +44,6 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import ot
-import os

 # %%
 # Generating the fitting data
diff --git a/ot/da.py b/ot/da.py
index dc6aa70a0..02e323611 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -8,6 +8,7 @@
 # Michael Perrot
 # Nathalie Gayraud
 # Ievgen Redko
+# Eloi Tanguy
 #
 # License: MIT License

@@ -22,6 +23,7 @@
 from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
 from .optim import cg
 from .optim import gcg
+from .mapping import nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds


 def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -2635,3 +2637,163 @@ def inverse_transform_labels(self, yt=None):
             transp_ys.append(nx.dot(D1, transp.T).T)

         return transp_ys
+
+
+class NearestBrenierPotential(BaseTransport):
+    r"""
+    Smooth Strongly Convex Nearest Brenier Potentials (SSNB) is a method from :ref:`[58]` that computes
+    an l-strongly convex potential :math:`\varphi` with an L-Lipschitz gradient such that
+    :math:`\nabla \varphi \# \mu \approx \nu`. This regularity can be enforced only on the components of a partition
+    of the ambient space (encoded by point classes), which is a relaxation compared to imposing global regularity.
+
+    SSNBs approach the target measure by solving the optimisation problem:
+
+    .. math::
+        \\varphi \in \\text{argmin}_{\\varphi \in \\mathcal{F}}\ \\text{W}_2(\\nabla \\varphi \#\\mu_s, \\mu_t),
+
+    where :math:`\mathcal{F}` is the space of functions that are on every set :math:`E_k` l-strongly convex
+    with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambient source space.
+
+    The problem is solved on "fitting" source and target data via a convex Quadratically Constrained Quadratic Program,
+    yielding the values :code:`phi` and the gradients :code:`G` at the source points.
+    The images of "new" source samples are then found by solving a (simpler) Quadratically Constrained Linear Program
+    at each point, using the fitting "parameters" :code:`phi` and :code:`G`. We provide two possible images, which
+    correspond to "lower" and "upper potentials" (:ref:`[59]`, Theorem 3.14). Each of these two images is an optimal
+    solution of the SSNB problem, and can be used in practice.
+
+    Parameters
+    ----------
+    strongly_convex_constant : float, optional
+        constant for the strong convexity of the input potential phi, defaults to 0.6
+    gradient_lipschitz_constant : float, optional
+        constant for the Lipschitz property of the input gradient G, defaults to 1.4
+    its: int, optional
+        number of iterations, defaults to 100
+    log : bool, optional
+        record log if true
+    seed: int or RandomState or None, optional
+        Seed used for random number generator (for the initialisation in :code:`fit`).
+
+    References
+    ----------
+
+    .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
+            Smooth and strongly convex brenier potentials in optimal transport. In International Conference
+            on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
+
+    .. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for
+            convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium,
+            2017.
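+
+    Examples
+    --------
+    A hedged sketch of the intended BaseTransport-style usage, fitting a toy source onto a translated
+    copy of itself (requires :code:`cvxpy`, and the fitted values depend on the ECOS solver):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> Xs = np.random.randn(8, 2)
+    >>> Xt = Xs + 2.  # a translated copy of the source as target
+    >>> ssnb = ot.da.NearestBrenierPotential(its=5)
+    >>> ssnb.fit(Xs=Xs, Xt=Xt)  # doctest: +SKIP
+    >>> G_lu = ssnb.transform(Xs)  # images by the lower and upper potentials  # doctest: +SKIP
+    >>> G_lu.shape  # doctest: +SKIP
+    (2, 8, 2)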
+
+    See Also
+    --------
+    ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data
+    ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data
+    """
+
+    def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None):
+        self.strongly_convex_constant = strongly_convex_constant
+        self.gradient_lipschitz_constant = gradient_lipschitz_constant
+        self.log = log
+        self.its = its
+        self.seed = seed
+        self.fit_log, self.predict_log = None, None
+        self.phi, self.G = None, None
+        self.fit_Xs, self.fit_ys, self.fit_Xt = None, None, None
+
+    def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+        r"""
+        Fits the Smooth Strongly Convex Nearest Brenier Potential [58] on the source data :code:`Xs` and the target
+        data :code:`Xt`, with the partition given by the (optional) labels :code:`ys`.
+
+        Wrapper for :code:`ot.mapping.nearest_brenier_potential_fit`.
+
+        THIS METHOD REQUIRES THE CVXPY LIBRARY
+
+        Parameters
+        ----------
+        Xs : array-like (n, d)
+            source points used to compute the optimal values phi and G
+        ys : array-like (n,), optional
+            classes of the source points, defaults to a single class
+        Xt : array-like (n, d)
+            target points onto which the source points are mapped
+        yt : optional
+            ignored.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+
+        References
+        ----------
+
+        .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
+                Smooth and strongly convex brenier potentials in optimal transport. In International Conference
+                on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
+
+        See Also
+        --------
+        ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data
+
+        """
+        self.fit_Xs, self.fit_ys, self.fit_Xt = Xs, ys, Xt
+        returned = nearest_brenier_potential_fit(Xs, Xt, X_classes=ys,
+                                                 strongly_convex_constant=self.strongly_convex_constant,
+                                                 gradient_lipschitz_constant=self.gradient_lipschitz_constant,
+                                                 its=self.its, log=self.log, seed=self.seed)
+
+        if self.log:
+            self.phi, self.G, self.fit_log = returned
+        else:
+            self.phi, self.G = returned
+
+        return self
+
+    def transform(self, Xs, ys=None):
+        r"""
+        Computes the images of the new source samples :code:`Xs` of classes :code:`ys` by the fitted
+        Smooth Strongly Convex Nearest Brenier Potentials (SSNB) :ref:`[58]`. The output is the pair of images of
+        :code:`Xs` by two SSNB optimal maps, called the 'lower' and 'upper' potentials (from :ref:`[59]`,
+        Theorem 3.14).
+
+        Wrapper for :code:`nearest_brenier_potential_predict_bounds`.
+
+        THIS METHOD REQUIRES THE CVXPY LIBRARY
+
+        Parameters
+        ----------
+        Xs : array-like (m, d)
+            input source points
+        ys : array_like (m,), optional
+            classes of the input source points, defaults to a single class
+
+        Returns
+        -------
+        G_lu : array-like (2, m, d)
+            gradients of the lower and upper bounding potentials at Y (images of the source inputs)
+
+        References
+        ----------
+
+        .. [58] François-Pierre Paty, Alexandre d’Aspremont, and Marco Cuturi. Regularity as regularization:
+                Smooth and strongly convex brenier potentials in optimal transport. In International Conference
+                on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020.
+
+        .. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for
+                convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium,
+                2017.
+ + See Also + -------- + ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data + + """ + returned = nearest_brenier_potential_predict_bounds( + self.fit_Xs, self.phi, self.G, Xs, X_classes=self.fit_ys, Y_classes=ys, + strongly_convex_constant=self.strongly_convex_constant, + gradient_lipschitz_constant=self.gradient_lipschitz_constant, log=self.log) + if self.log: + _, G_lu, self.predict_log = returned + else: + _, G_lu = returned + return G_lu diff --git a/ot/mapping.py b/ot/mapping.py index 2e9473f29..e1bba7481 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -45,17 +45,19 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly Accepts any compatible backend, but will perform the QCQP optimisation on Numpy arrays, and convert back at the end. + THIS FUNCTION REQUIRES THE CVXPY LIBRARY + Parameters ---------- - X: array-like (n, d) + X : array-like (n, d) reference points used to compute the optimal values phi and G - V: array-like (n, d) + V : array-like (n, d) values of the gradients at the reference points X X_classes : array-like (n,), optional classes of the reference points, defaults to a single class - a: array-like (n,), optional + a : array-like (n,), optional weights for the reference points X, defaults to uniform - b: array-like (n,), optional + b : array-like (n,), optional weights for the target points V, defaults to uniform strongly_convex_constant : float, optional constant for the strong convexity of the input potential phi, defaults to 0.6 @@ -63,8 +65,6 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly constant for the Lipschitz property of the input gradient G, defaults to 1.4 its: int, optional number of iterations, defaults to 100 - pbar: bool, optional - if True show a progress bar, defaults to False log : bool, optional record log if true seed: int or RandomState or None, optional @@ -87,6 +87,11 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly Smooth and strongly convex brenier potentials in optimal transport. In International Conference on Artificial Intelligence and Statistics, pages 1222–1232. PMLR, 2020. + See Also + -------- + ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data + ot.da.NearestBrenierPotential : BaseTransport wrapper for SSNB + """ try: import cvxpy as cvx @@ -197,7 +202,8 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla Compute the values of the lower and upper bounding potentials at the input points Y, using the potential optimal values phi at X and their gradients G at X. The 'lower' potential corresponds to the method from :ref:`[58]`, Equation 2, while the bounding property and 'upper' potential come from :ref:`[59]`, Theorem 3.14 (taking into - account the fact that this theorem's statement has a min instead of a max, which is a typo). + account the fact that this theorem's statement has a min instead of a max, which is a typo). Both potentials are + optimal for the SSNB problem. If :math:`I_k` is the subset of :math:`[n]` of the i such that :math:`x_i` is in the partition (or class) :math:`E_k`, for each :math:`y \in E_k`, this function solves the convex QCQP problems, @@ -221,6 +227,8 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`. 
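[Editor's note: for orientation, a worked instance of these constants, with closed forms inferred from the interpolation conditions of :ref:`[59]`; the exact expressions inside :code:`ssnb_qcqp_constants` are an assumption here, not quoted from the patch. Writing :math:`l` for `strongly_convex_constant`, :math:`L` for `gradient_lipschitz_constant` and :math:`q = 1/(2(1 - l/L))`, one gets :math:`c_1 = q/L`, :math:`c_2 = lq` and :math:`c_3 = 2lq/L`. For :math:`(l, L) = (0.5, 1)` this evaluates to :math:`(c_1, c_2, c_3) = (1, 0.5, 1)`, consistent with the values asserted for :math:`c_1` and :math:`c_3` in :code:`test_ssnb_qcqp_constants` later in this series.]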
+ THIS FUNCTION REQUIRES THE CVXPY LIBRARY + Parameters ---------- X : array-like (n, d) @@ -264,6 +272,11 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla convex optimization. PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. + See Also + -------- + ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data + ot.da.NearestBrenierPotential : BaseTransport wrapper for SSNB + """ try: import cvxpy as cvx diff --git a/test/test_da.py b/test/test_da.py index dd7d1e0c8..8dcb83411 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -351,69 +351,70 @@ def test_unbalanced_sinkhorn_transport_class(nx): Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) - otda = ot.da.UnbalancedSinkhornTransport() + for log in [True, False]: + otda = ot.da.UnbalancedSinkhornTransport(log=log) - # test its computed - otda.fit(Xs=Xs, Xt=Xt) - assert hasattr(otda, "cost_") - assert hasattr(otda, "coupling_") - assert hasattr(otda, "log_") + # test its computed + otda.fit(Xs=Xs, Xt=Xt) + assert hasattr(otda, "cost_") + assert hasattr(otda, "coupling_") + assert hasattr(otda, "log_") - # test dimensions of coupling - assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + # test dimensions of coupling + assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) + # test transform + transp_Xs = otda.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) - # check label propagation - transp_yt = otda.transform_labels(ys) - assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(ys))) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) - # check inverse label propagation - transp_ys = otda.inverse_transform_labels(yt) - assert_equal(transp_ys.shape[0], ys.shape[0]) - assert_equal(transp_ys.shape[1], len(np.unique(yt))) + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) - transp_Xs_new = otda.transform(Xs_new) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + transp_Xs_new = otda.transform(Xs_new) - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) - # test inverse transform - transp_Xt = otda.inverse_transform(Xt=Xt) - assert_equal(transp_Xt.shape, Xt.shape) + # test inverse transform + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) - transp_Xt_new = otda.inverse_transform(Xt=Xt_new) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + transp_Xt_new = otda.inverse_transform(Xt=Xt_new) - # check that the oos method is working - assert_equal(transp_Xt_new.shape, Xt_new.shape) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) - # test fit_transform - transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) - assert_equal(transp_Xs.shape, Xs.shape) + # test 
fit_transform + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) - # test unsupervised vs semi-supervised mode - otda_unsup = ot.da.SinkhornTransport() - otda_unsup.fit(Xs=Xs, Xt=Xt) - n_unsup = nx.sum(otda_unsup.cost_) + # test unsupervised vs semi-supervised mode + otda_unsup = ot.da.SinkhornTransport() + otda_unsup.fit(Xs=Xs, Xt=Xt) + n_unsup = nx.sum(otda_unsup.cost_) - otda_semi = ot.da.SinkhornTransport() - otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) - assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = nx.sum(otda_semi.cost_) + otda_semi = ot.da.SinkhornTransport() + otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = nx.sum(otda_semi.cost_) - # check that the cost matrix norms are indeed different - assert n_unsup != n_semisup, "semisupervised mode not working" + # check that the cost matrix norms are indeed different + assert n_unsup != n_semisup, "semisupervised mode not working" - # check everything runs well with log=True - otda = ot.da.SinkhornTransport(log=True) - otda.fit(Xs=Xs, ys=ys, Xt=Xt) - assert len(otda.log_.keys()) != 0 + # check everything runs well with log=True + otda = ot.da.SinkhornTransport(log=True) + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert len(otda.log_.keys()) != 0 @pytest.skip_backend("jax") @@ -585,20 +586,28 @@ def test_linear_mapping_class(nx): Xsb, Xtb = nx.from_numpy(Xs, Xt) - otmap = ot.da.LinearTransport() + for log in [True, False]: + otmap = ot.da.LinearTransport(log=log) - otmap.fit(Xs=Xsb, Xt=Xtb) - assert hasattr(otmap, "A_") - assert hasattr(otmap, "B_") - assert hasattr(otmap, "A1_") - assert hasattr(otmap, "B1_") + otmap.fit(Xs=Xsb, Xt=Xtb) + assert hasattr(otmap, "A_") + assert hasattr(otmap, "B_") + assert hasattr(otmap, "A1_") + assert hasattr(otmap, "B1_") - Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) + Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + Xts = nx.to_numpy(otmap.inverse_transform(Xt=Xt)) + + Cs = np.cov(Xs.T) + Cts = np.cov(Xts.T) + + np.testing.assert_allclose(Cs, Cts, rtol=1e-2, atol=1e-2) @pytest.skip_backend("jax") @@ -612,20 +621,21 @@ def test_linear_gw_mapping_class(nx): Xsb, Xtb = nx.from_numpy(Xs, Xt) - otmap = ot.da.LinearGWTransport() + for log in [True, False]: + otmap = ot.da.LinearGWTransport(log=log) - otmap.fit(Xs=Xsb, Xt=Xtb) - assert hasattr(otmap, "A_") - assert hasattr(otmap, "B_") - assert hasattr(otmap, "A1_") - assert hasattr(otmap, "B1_") + otmap.fit(Xs=Xsb, Xt=Xtb) + assert hasattr(otmap, "A_") + assert hasattr(otmap, "B_") + assert hasattr(otmap, "A1_") + assert hasattr(otmap, "B1_") - Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) + Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) @pytest.skip_backend("jax") @@ -648,58 +658,60 @@ def test_jcpot_transport_class(nx): Xs = [Xs1, Xs2] ys = [ys1, ys2] - otda = ot.da.JCPOTTransport(reg_e=1, max_iter=10000, tol=1e-9, verbose=True, log=True) + for log in [True, False]: + otda = ot.da.JCPOTTransport(reg_e=1, max_iter=10000, tol=1e-9, verbose=True, log=log) - # test its computed - otda.fit(Xs=Xs, ys=ys, Xt=Xt) + # test its computed + 
otda.fit(Xs=Xs, ys=ys, Xt=Xt) - assert hasattr(otda, "coupling_") - assert hasattr(otda, "proportions_") - assert hasattr(otda, "log_") + assert hasattr(otda, "coupling_") + assert hasattr(otda, "proportions_") + assert hasattr(otda, "log_") - # test dimensions of coupling - for i, xs in enumerate(Xs): - assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0]))) + # test dimensions of coupling + for i, xs in enumerate(Xs): + assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0]))) - # test all margin constraints - mu_t = unif(nt) + # test all margin constraints + mu_t = unif(nt) - for i in range(len(Xs)): - # test margin constraints w.r.t. uniform target weights for each coupling matrix - assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3) + for i in range(len(Xs)): + # test margin constraints w.r.t. uniform target weights for each coupling matrix + assert_allclose( + nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3) - # test margin constraints w.r.t. modified source weights for each source domain + if log: + # test margin constraints w.r.t. modified source weights for each source domain - assert_allclose( - nx.to_numpy( - nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1)) - ), - nx.to_numpy(otda.proportions_), - rtol=1e-3, - atol=1e-3 - ) + assert_allclose( + nx.to_numpy( + nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1)) + ), + nx.to_numpy(otda.proportions_), + rtol=1e-3, + atol=1e-3 + ) - # test transform - transp_Xs = otda.transform(Xs=Xs) - [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] + # test transform + transp_Xs = otda.transform(Xs=Xs) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0]) - transp_Xs_new = otda.transform(Xs_new) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0]) + transp_Xs_new = otda.transform(Xs_new) - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) - # check label propagation - transp_yt = otda.transform_labels(ys) - assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys)))) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys)))) - # check inverse label propagation - transp_ys = otda.inverse_transform_labels(yt) - for x, y in zip(transp_ys, ys): - assert_equal(x.shape[0], y.shape[0]) - assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y)))) + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + for x, y in zip(transp_ys, ys): + assert_equal(x.shape[0], y.shape[0]) + assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y)))) def test_jcpot_barycenter(nx): @@ -745,56 +757,99 @@ def test_emd_laplace_class(nx): Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) - otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True) + for log in [True, False]: + otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=log) - # test its computed - otda.fit(Xs=Xs, ys=ys, Xt=Xt) + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) - assert hasattr(otda, "coupling_") - assert hasattr(otda, "log_") + assert hasattr(otda, "coupling_") + assert 
hasattr(otda, "log_") - # test dimensions of coupling - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + # test dimensions of coupling + assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - # test all margin constraints - mu_s = unif(ns) - mu_t = unif(nt) + # test all margin constraints + mu_s = unif(ns) + mu_t = unif(nt) - assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) + assert_allclose( + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) - # test transform - transp_Xs = otda.transform(Xs=Xs) - [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] + # test transform + transp_Xs = otda.transform(Xs=Xs) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) - transp_Xs_new = otda.transform(Xs_new) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) + transp_Xs_new = otda.transform(Xs_new) - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) - # test inverse transform - transp_Xt = otda.inverse_transform(Xt=Xt) - assert_equal(transp_Xt.shape, Xt.shape) + # test inverse transform + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) - Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) - transp_Xt_new = otda.inverse_transform(Xt=Xt_new) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) + transp_Xt_new = otda.inverse_transform(Xt=Xt_new) - # check that the oos method is working - assert_equal(transp_Xt_new.shape, Xt_new.shape) + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) - # test fit_transform - transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) - assert_equal(transp_Xs.shape, Xs.shape) + # test fit_transform + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) - # check label propagation - transp_yt = otda.transform_labels(ys) - assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys)))) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys)))) - # check inverse label propagation - transp_ys = otda.inverse_transform_labels(yt) - assert_equal(transp_ys.shape[0], ys.shape[0]) - assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt)))) + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt)))) + + +def test_nearest_brenier_potential(nx): + X = nx.ones((2, 2)) + for ssnb in [ot.da.NearestBrenierPotential(log=True), ot.da.NearestBrenierPotential(log=False)]: + ssnb.fit(Xs=X, Xt=X) + G_lu = ssnb.transform(Xs=X) + np.testing.assert_almost_equal(G_lu[0], X) # 'new' input isn't new, so should be equal to target + np.testing.assert_almost_equal(G_lu[1], X) + + +@pytest.mark.skipif(nosklearn, reason="No sklearn available") +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_emd_laplace(nx): + """Complements 
:code:`test_emd_laplace_class` for uncovered options in :code:`emd_laplace`""" + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + M = ot.dist(Xs, Xt) + with pytest.raises(ValueError): + ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim_param=['INVALID', 'INPUT', 2]) + with pytest.raises(ValueError): + ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim=['INVALID', 'INPUT', 2]) + + # test all margin constraints with gaussian similarity and disp regularisation + coupling = ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim='gauss', reg='disp') + + assert_allclose( + nx.to_numpy(nx.sum(coupling, axis=0)), unif(nt), rtol=1e-3, atol=1e-3) + assert_allclose( + nx.to_numpy(nx.sum(coupling, axis=1)), unif(ns), rtol=1e-3, atol=1e-3) + + +def test_joint_OT_mapping_verbose(): + xs = np.zeros((2, 1)) + ot.da.joint_OT_mapping_kernel(xs, xs, verbose=True) + ot.da.joint_OT_mapping_linear(xs, xs, verbose=True) From 60944f52f6ab0f3508fe47ae94667cddd0fe0d43 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 11:16:32 +0200 Subject: [PATCH 08/21] guessed PR number --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 51f7dc8ad..b6556572d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,7 +3,7 @@ ## 0.9.2dev #### New features -+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #) ++ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526) + Tweaked `get_backend` to ignore `None` inputs (PR # 525) #### Closed issues From 7bc3213386a9d756648610fab5a6a65224caa462 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 11:26:13 +0200 Subject: [PATCH 09/21] removed unused import --- RELEASES.md | 2 +- test/test_mapping.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 081164824..3dfd23c3b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,7 +4,7 @@ #### New features + Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526) -+ Tweaked `get_backend` to ignore `None` inputs (PR # 525) ++ Tweaked `get_backend` to ignore `None` inputs (PR #525) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/test/test_mapping.py b/test/test_mapping.py index 2685dcea5..8119f538c 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -4,7 +4,6 @@ # License: MIT License import numpy as np -import pytest import ot From 55f0e095e9569690896ed6888413f3888a4ae32d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 11:41:22 +0200 Subject: [PATCH 10/21] PEP8 tab errors fix --- ot/mapping.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ot/mapping.py b/ot/mapping.py index e1bba7481..a275fb1d6 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -31,8 +31,8 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly \\text{s.t.}\ \\forall k \in [K],\; \\forall i,j \in I_k: - \\varphi_i-\\varphi_j-\langle g_j, x_i-x_j\\rangle \geq c_1\|g_i - g_j\|_2^2 + - c_2\|x_i-x_j\|_2^2 - c_3\langle g_j-g_i, x_j -x_i \\rangle. 
+ \\varphi_i-\\varphi_j-\langle g_j, x_i-x_j\\rangle \geq c_1\|g_i - g_j\|_2^2 + + c_2\|x_i-x_j\|_2^2 - c_3\langle g_j-g_i, x_j -x_i \\rangle. The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`. The constraint :math:`\pi \in \Pi(a, b)` denotes the fact that the matrix :math:`\pi` belong to the OT polytope @@ -214,16 +214,16 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla t\in \mathbb{R},\; g\in \mathbb{R}^d, - \\text{s.t.} \\forall j \in I_k,\; t-\\varphi_j - \langle g_j, y-x_j \\rangle \geq c_1\|g - g_j\|_2^2 - + c_2\|y-x_j\|_2^2 - c_3\langle g_j-g, x_j -y \\rangle. + \\text{s.t.} \\forall j \in I_k,\; t-\\varphi_j - \langle g_j, y-x_j \\rangle \geq c_1\|g - g_j\|_2^2 + + c_2\|y-x_j\|_2^2 - c_3\langle g_j-g, x_j -y \\rangle. .. math:: (\\varphi_{u}(x), \\nabla \\varphi_u(x)) = \\text{argmax}\ t, t\in \mathbb{R},\; g\in \mathbb{R}^d, - \\text{s.t.} \\forall i \in I_k,\; \\varphi_i^* -t - \langle g, x_i-y \\rangle \geq c_1\|g_i - g\|_2^2 - + c_2\|x_i-y\|_2^2 - c_3\langle g-g_i, y -x_i \\rangle. + \\text{s.t.} \\forall i \in I_k,\; \\varphi_i^* -t - \langle g, x_i-y \\rangle \geq c_1\|g_i - g\|_2^2 + + c_2\|x_i-y\|_2^2 - c_3\langle g-g_i, y -x_i \\rangle. The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`. From 9dfd82ed1ccd8b21764c3afb20a415dfee9d485e Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 11:44:40 +0200 Subject: [PATCH 11/21] skip ssnb test if no cvxpy --- test/test_da.py | 7 +++++++ test/test_mapping.py | 11 +++++++++++ 2 files changed, 18 insertions(+) diff --git a/test/test_da.py b/test/test_da.py index c3eb25ca6..795ba4b01 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -19,6 +19,12 @@ except ImportError: nosklearn = True +try: # test if cvxpy is installed + import cvxpy # noqa: F401 + nocvxpy = False +except ImportError: + npcvxpy = True + def test_class_jax_tf(): backends = [] @@ -816,6 +822,7 @@ def test_emd_laplace_class(nx): assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt)))) +@pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_nearest_brenier_potential(nx): X = nx.ones((2, 2)) for ssnb in [ot.da.NearestBrenierPotential(log=True), ot.da.NearestBrenierPotential(log=False)]: diff --git a/test/test_mapping.py b/test/test_mapping.py index 8119f538c..dee3fe913 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -5,8 +5,17 @@ import numpy as np import ot +import pytest +try: # test if cvxpy is installed + import cvxpy # noqa: F401 + nocvxpy = False +except ImportError: + npcvxpy = True + + +@pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_ssnb_qcqp_constants(): c1, c2, c3 = ot.mapping.ssnb_qcqp_constants(.5, 1) np.testing.assert_almost_equal(c1, 1) @@ -14,6 +23,7 @@ def test_ssnb_qcqp_constants(): np.testing.assert_almost_equal(c3, 1) +@pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_nearest_brenier_potential_fit(nx): X = nx.ones((2, 2)) phi, G, log = ot.nearest_brenier_potential_fit(X, X, its=3, log=True) @@ -24,6 +34,7 @@ def test_nearest_brenier_potential_fit(nx): ot.nearest_brenier_potential_fit(X, X, its=1, seed=np.random.RandomState(seed=0)) +@pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_brenier_potential_predict_bounds(nx): X = nx.ones((2, 2)) phi, G = ot.nearest_brenier_potential_fit(X, X, its=3) From 048939205c07d5d9d02edbace4ff1db1b263aa7d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 
Sep 2023 15:15:51 +0200 Subject: [PATCH 12/21] test and doc fixes --- docs/source/all.rst | 1 + examples/others/plot_SSNB.py | 1 + ot/da.py | 15 ++++-- ot/mapping.py | 97 ++++++++++++++++++++---------------- test/test_da.py | 4 +- test/test_mapping.py | 8 +-- 6 files changed, 75 insertions(+), 51 deletions(-) diff --git a/docs/source/all.rst b/docs/source/all.rst index 8750074c3..872a48528 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -25,6 +25,7 @@ API and modules gnn gromov lp + mapping optim partial plot diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index bc7631554..249343b93 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# sphinx_gallery_thumbnail_number = 2 r""" ===================================================== Smooth and Strongly Convex Nearest Brenier Potentials diff --git a/ot/da.py b/ot/da.py index 02e323611..6ae6f77a9 100644 --- a/ot/da.py +++ b/ot/da.py @@ -2649,7 +2649,11 @@ class NearestBrenierPotential(BaseTransport): SSNBs approach the target measure by solving the optimisation problem: .. math:: - \\varphi \in \\text{argmin}_{\\varphi \in \\mathcal{F}}\ \\text{W}_2(\\nabla \\varphi \#\\mu_s, \\mu_t), + :nowrap: + + \begin{gather*} + \varphi \in \text{argmin}_{\varphi \in \mathcal{F}}\ \text{W}_2(\nabla \varphi \#\mu_s, \mu_t), + \end{gather*} where :math:`\mathcal{F}` is the space functions that are on every set :math:`E_k` l-strongly convex with an L-Lipschitz gradient, given :math:`(E_k)_{k \in [K]}` a partition of the ambient source space. @@ -2661,6 +2665,9 @@ class NearestBrenierPotential(BaseTransport): correspond to "lower" and "upper potentials" (:ref:`[59]`, Theorem 3.14). Each of these two images are optimal solutions of the SSNB problem, and can be used in practice. + .. warning:: This function requires the CVXPY library + .. warning:: Accepts any backend but will convert to Numpy then back to the backend. + Parameters ---------- strongly_convex_constant : float, optional @@ -2707,7 +2714,8 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None): Wrapper for :code:`ot.mapping.nearest_brenier_potential_fit`. - THIS METHOD REQUIRES THE CVXPY LIBRARY + .. warning:: This function requires the CVXPY library + .. warning:: Accepts any backend but will convert to Numpy then back to the backend. Parameters ---------- @@ -2758,7 +2766,8 @@ def transform(self, Xs, ys=None): Wrapper for :code:`nearest_brenier_potential_predict_bounds`. - THIS METHOD REQUIRES THE CVXPY LIBRARY + .. warning:: This function requires the CVXPY library + .. warning:: Accepts any backend but will convert to Numpy then back to the backend. Parameters ---------- diff --git a/ot/mapping.py b/ot/mapping.py index a275fb1d6..028843590 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -16,36 +16,38 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6, gradient_lipschitz_constant=1.4, its=100, log=False, seed=None): r""" - Computes optimal values and gradients at X for a strongly convex potential :math:`\\varphi` with Lipschitz gradients - on the partitions defined by `X_classes`, where :math:`\\varphi` is optimal such that - :math:`\\nabla \\varphi \#\\mu \\approx \\nu`, given samples :math:`X = x_1, \\cdots, x_n \\sim \\mu` and - :math:`V = v_1, \\cdots, v_n \\sim \\nu`. 
Finding such a potential that has the desired regularity on the + Computes optimal values and gradients at X for a strongly convex potential :math:`\varphi` with Lipschitz gradients + on the partitions defined by `X_classes`, where :math:`\varphi` is optimal such that + :math:`\nabla \varphi \#\mu \approx \nu`, given samples :math:`X = x_1, \cdots, x_n \sim \mu` and + :math:`V = v_1, \cdots, v_n \sim \nu`. Finding such a potential that has the desired regularity on the partition :math:`(E_k)_{k \in [K]}` (given by the classes `X_classes`) is equivalent to finding optimal values - `phi` for the :math:`\\varphi(x_i)` and its gradients :math:`\\nabla \\varphi(x_i)` (variable`G`). + `phi` for the :math:`\varphi(x_i)` and its gradients :math:`\nabla \varphi(x_i)` (variable`G`). In practice, these optimal values are found by solving the following problem .. math:: - \\text{min} \\sum_{i,j}\\pi_{i,j}\|g_i - v_j\|_2^2 + :nowrap: - g_1,\\cdots, g_n \in \mathbb{R}^d,\; \\varphi_1, \\cdots, \\varphi_n \in \mathbb{R},\; \pi \in \Pi(a, b) - - \\text{s.t.}\ \\forall k \in [K],\; \\forall i,j \in I_k: - - \\varphi_i-\\varphi_j-\langle g_j, x_i-x_j\\rangle \geq c_1\|g_i - g_j\|_2^2 + - c_2\|x_i-x_j\|_2^2 - c_3\langle g_j-g_i, x_j -x_i \\rangle. + \begin{gather*} + \text{min} \sum_{i,j}\pi_{i,j}\|g_i - v_j\|_2^2 \\ + g_1,\cdots, g_n \in \mathbb{R}^d,\; \varphi_1, \cdots, \varphi_n \in \mathbb{R},\; \pi \in \Pi(a, b) \\ + \text{s.t.}\ \forall k \in [K],\; \forall i,j \in I_k: \\ + \varphi_i-\varphi_j-\langle g_j, x_i-x_j\rangle \geq c_1\|g_i - g_j\|_2^2 + + c_2\|x_i-x_j\|_2^2 - c_3\langle g_j-g_i, x_j -x_i \rangle. + \end{gather*} The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`. The constraint :math:`\pi \in \Pi(a, b)` denotes the fact that the matrix :math:`\pi` belong to the OT polytope of marginals a and b. :math:`I_k` is the subset of :math:`[n]` of the i such that :math:`x_i` is in the partition (or class) :math:`E_k`, i.e. `X_classes[i] == k`. - This problem is solved by alternating over the variable :math:`\pi` and the variables :math:`\\varphi_i, g_i`. - For :math:`\pi`, the problem is the standard discrete OT problem, and for :math:`\\varphi_i, g_i`, the + This problem is solved by alternating over the variable :math:`\pi` and the variables :math:`\varphi_i, g_i`. + For :math:`\pi`, the problem is the standard discrete OT problem, and for :math:`\varphi_i, g_i`, the problem is a convex QCQP solved using :code:`cvxpy` (ECOS solver). Accepts any compatible backend, but will perform the QCQP optimisation on Numpy arrays, and convert back at the end. - THIS FUNCTION REQUIRES THE CVXPY LIBRARY + .. warning:: This function requires the CVXPY library + .. warning:: Accepts any backend but will convert to Numpy then back to the backend. 
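[Editor's note: to make the alternating scheme concrete, a schematic, self-contained sketch of one outer iteration follows: a pi-step solved by discrete OT, then the (phi, G)-step as a cvxpy QCQP mirroring the constraints built in this patch. Toy data and variable names are illustrative.]

    # schematic outer iteration of nearest_brenier_potential_fit (requires cvxpy)
    import cvxpy as cvx
    import numpy as np
    import ot

    n, d = 6, 2
    rng = np.random.RandomState(0)
    X, V = rng.randn(n, d), 2.0 * rng.randn(n, d)          # source points, target samples
    X_classes = np.zeros(n)                                # a single class
    c1, c2, c3 = ot.mapping.ssnb_qcqp_constants(0.6, 1.4)  # interpolation constants
    G_val = rng.randn(n, d)                                # current gradient iterate

    # pi-step: standard discrete OT between current gradients and targets
    plan = ot.emd(ot.unif(n), ot.unif(n), ot.dist(G_val, V))

    # (phi, G)-step: convex QCQP with the coupling fixed
    phi, G = cvx.Variable(n), cvx.Variable((n, d))
    cost = sum(plan[i, j] * cvx.sum_squares(G[i] - V[j])
               for i in range(n) for j in range(n))
    constraints = []
    for k in np.unique(X_classes):                         # per-class interpolation constraints
        idx = np.where(X_classes == k)[0]
        for i in idx:
            for j in idx:
                constraints.append(
                    phi[i] >= phi[j] + G[j].T @ (X[i] - X[j])
                    + c1 * cvx.sum_squares(G[i] - G[j])
                    + c2 * cvx.sum_squares(X[i] - X[j])
                    - c3 * (G[j] - G[i]).T @ (X[j] - X[i]))
    cvx.Problem(cvx.Minimize(cost), constraints).solve(solver=cvx.ECOS)
    G_val = G.value                                        # feeds the next pi-step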
Parameters ---------- @@ -103,14 +105,15 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly X, V = to_numpy(X), to_numpy(V) n, d = X.shape if X_classes is not None: + X_classes = to_numpy(X_classes) assert X_classes.size == n, "incorrect number of class items" else: - X_classes = nx.zeros(n) + X_classes = np.zeros(n) if a is None: - a = unif(n, type_as=X) + a = unif(n) if b is None: - b = unif(n, type_as=X) - assert a.size == b.size == n, 'incorrect measure weight sizes' + b = unif(n) + assert a.shape[-1] == b.shape[-1] == n, 'incorrect measure weight sizes' if isinstance(seed, np.random.RandomState): G_val = np.random.randn(n, d) @@ -142,9 +145,9 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly objective = cvx.Minimize(cost) # OT cost c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) - for k in nx.unique(X_classes): # constraints for the convex interpolation - for i in nx.where(X_classes == k)[0]: - for j in nx.where(X_classes == k)[0]: + for k in np.unique(X_classes): # constraints for the convex interpolation + for i in np.where(X_classes == k)[0]: + for j in np.where(X_classes == k)[0]: constraints += [ phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) \ + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) @@ -210,24 +213,29 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla respectively for l: 'lower' and u: 'upper': .. math:: - (\\varphi_{l}(x), \\nabla \\varphi_l(x)) = \\text{argmin}\ t, + :nowrap: - t\in \mathbb{R},\; g\in \mathbb{R}^d, - - \\text{s.t.} \\forall j \in I_k,\; t-\\varphi_j - \langle g_j, y-x_j \\rangle \geq c_1\|g - g_j\|_2^2 - + c_2\|y-x_j\|_2^2 - c_3\langle g_j-g, x_j -y \\rangle. + \begin{gather*} + (\varphi_{l}(x), \nabla \varphi_l(x)) = \text{argmin}\ t, \\ + t\in \mathbb{R},\; g\in \mathbb{R}^d, \\ + \text{s.t.} \forall j \in I_k,\; t-\varphi_j - \langle g_j, y-x_j \rangle \geq c_1\|g - g_j\|_2^2 + + c_2\|y-x_j\|_2^2 - c_3\langle g_j-g, x_j -y \rangle. + \end{gather*} .. math:: - (\\varphi_{u}(x), \\nabla \\varphi_u(x)) = \\text{argmax}\ t, - - t\in \mathbb{R},\; g\in \mathbb{R}^d, + :nowrap: - \\text{s.t.} \\forall i \in I_k,\; \\varphi_i^* -t - \langle g, x_i-y \\rangle \geq c_1\|g_i - g\|_2^2 - + c_2\|x_i-y\|_2^2 - c_3\langle g-g_i, y -x_i \\rangle. + \begin{gather*} + (\varphi_{u}(x), \nabla \varphi_u(x)) = \text{argmax}\ t, \\ + t\in \mathbb{R},\; g\in \mathbb{R}^d, \\ + \text{s.t.} \forall i \in I_k,\; \varphi_i^* -t - \langle g, x_i-y \rangle \geq c_1\|g_i - g\|_2^2 + + c_2\|x_i-y\|_2^2 - c_3\langle g-g_i, y -x_i \rangle. + \end{gather*} The constants :math:`c_1, c_2, c_3` only depend on `strongly_convex_constant` and `gradient_lipschitz_constant`. - THIS FUNCTION REQUIRES THE CVXPY LIBRARY + .. warning:: This function requires the CVXPY library + .. warning:: Accepts any backend but will convert to Numpy then back to the backend. 
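[Editor's note: a short usage sketch of this prediction step, mirroring the unit tests of this series; toy inputs only, cvxpy required.]

    # fit on toy data, then compute lower/upper bounding potentials at new points
    import numpy as np
    import ot

    X = np.ones((2, 2))
    phi, G = ot.nearest_brenier_potential_fit(X, X, its=3)
    Y = np.zeros((3, 2))                        # new input points
    phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(X, phi, G, Y)
    print(phi_lu.shape, G_lu.shape)             # (2, 3) and (2, 3, 2)
    # phi_lu[0] / phi_lu[1]: lower / upper potential values at Y
    # G_lu[0] / G_lu[1]: their gradients, i.e. the two candidate images of Y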
Parameters ---------- @@ -290,19 +298,21 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla Y = to_numpy(Y) m, d = Y.shape if Y_classes is not None: + Y_classes = to_numpy(Y_classes) assert Y_classes.size == m, 'wrong number of class items for Y' else: - Y_classes = nx.zeros(m) + Y_classes = np.zeros(m) assert X.shape[1] == d, f'incompatible dimensions between X: {X.shape} and Y: {Y.shape}' n, _ = X.shape if X_classes is not None: + X_classes = to_numpy(X_classes) assert X_classes.size == n, "incorrect number of class items" else: - X_classes = nx.zeros(n) + X_classes = np.zeros(n) assert X_classes.size == n, 'wrong number of class items for X' c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) - phi_lu = nx.zeros((2, m)) - G_lu = nx.zeros((2, m, d)) + phi_lu = np.zeros((2, m)) + G_lu = np.zeros((2, m, d)) log_dict = {} for y_idx in range(m): @@ -313,15 +323,15 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla objective = cvx.Minimize(phi_l_y) constraints = [] k = Y_classes[y_idx] - for j in nx.where(X_classes == k)[0]: + for j in np.where(X_classes == k)[0]: constraints += [ phi_l_y >= phi[j] + G[j].T @ (Y[y_idx] - X[j]) + c1 * cvx.sum_squares(G_l_y - G[j]) \ + c2 * cvx.sum_squares(Y[y_idx] - X[j]) - c3 * (G[j] - G_l_y).T @ (X[j] - Y[y_idx]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) - phi_lu[0, y_idx] = nx.from_numpy(phi_l_y.value, type_as=X) - G_lu[0, y_idx] = nx.from_numpy(G_l_y.value, type_as=X) + phi_lu[0, y_idx] = phi_l_y.value + G_lu[0, y_idx] = G_l_y.value if log: log_item['l'] = { 'solve_time': problem.solver_stats.solve_time, @@ -336,15 +346,15 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla G_u_y = cvx.Variable(d) objective = cvx.Maximize(phi_u_y) constraints = [] - for i in nx.where(X_classes == k)[0]: + for i in np.where(X_classes == k)[0]: constraints += [ phi[i] >= phi_u_y + G_u_y.T @ (X[i] - Y[y_idx]) + c1 * cvx.sum_squares(G[i] - G_u_y) \ + c2 * cvx.sum_squares(X[i] - Y[y_idx]) - c3 * (G_u_y - G[i]).T @ (Y[y_idx] - X[i]) ] problem = cvx.Problem(objective, constraints) problem.solve(solver=cvx.ECOS) - phi_lu[1, y_idx] = nx.from_numpy(phi_u_y.value, type_as=X) - G_lu[1, y_idx] = nx.from_numpy(G_u_y.value, type_as=X) + phi_lu[1, y_idx] = phi_u_y.value + G_lu[1, y_idx] = G_u_y.value if log: log_item['u'] = { 'solve_time': problem.solver_stats.solve_time, @@ -355,6 +365,7 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla } log_dict[y_idx] = log_item + phi_lu, G_lu = nx.from_numpy(phi_lu), nx.from_numpy(G_lu) if not log: return phi_lu, G_lu return phi_lu, G_lu, log_dict diff --git a/test/test_da.py b/test/test_da.py index 795ba4b01..d3ce317e9 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -611,7 +611,7 @@ def test_linear_mapping_class(nx): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) - Xts = nx.to_numpy(otmap.inverse_transform(Xt=Xt)) + Xts = nx.to_numpy(otmap.inverse_transform(Xt=Xtb)) Cs = np.cov(Xs.T) Cts = np.cov(Xts.T) @@ -851,7 +851,7 @@ def test_emd_laplace(nx): ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim=['INVALID', 'INPUT', 2]) # test all margin constraints with gaussian similarity and disp regularisation - coupling = ot.da.emd_laplace(ot.unif(ns), ot.unif(nt), Xs, Xt, M, sim='gauss', reg='disp') + coupling = ot.da.emd_laplace(ot.unif(ns, type_as=Xs), ot.unif(nt, type_as=Xs), Xs, Xt, M, sim='gauss', reg='disp') assert_allclose( 
nx.to_numpy(nx.sum(coupling, axis=0)), unif(nt), rtol=1e-3, atol=1e-3) diff --git a/test/test_mapping.py b/test/test_mapping.py index dee3fe913..7b11d0094 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -6,6 +6,7 @@ import numpy as np import ot import pytest +from ot.backend import to_numpy try: # test if cvxpy is installed @@ -27,7 +28,7 @@ def test_ssnb_qcqp_constants(): def test_nearest_brenier_potential_fit(nx): X = nx.ones((2, 2)) phi, G, log = ot.nearest_brenier_potential_fit(X, X, its=3, log=True) - np.testing.assert_almost_equal(G, X) # image of source should be close to target + np.testing.assert_almost_equal(to_numpy(G), to_numpy(X)) # image of source should be close to target # test without log but with X_classes and seed ot.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), its=1, seed=0) # test with seed being a np.random.RandomState @@ -39,7 +40,8 @@ def test_brenier_potential_predict_bounds(nx): X = nx.ones((2, 2)) phi, G = ot.nearest_brenier_potential_fit(X, X, its=3) phi_lu, G_lu, log = ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, log=True) - np.testing.assert_almost_equal(G_lu[0], X) # 'new' input isn't new, so should be equal to target - np.testing.assert_almost_equal(G_lu[1], X) + # 'new' input isn't new, so should be equal to target + np.testing.assert_almost_equal(to_numpy(G_lu[0]), to_numpy(X)) + np.testing.assert_almost_equal(to_numpy(G_lu[1]), to_numpy(X)) # test with no log but classes ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2)) From 2adcab311615f2fcedb7ea4c76fd38164144ecdd Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 15:30:28 +0200 Subject: [PATCH 13/21] doc dependency + minor comment in ot __init__.py --- docs/requirements_rtd.txt | 3 ++- ot/__init__.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index 30082bb50..323d95134 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -11,4 +11,5 @@ matplotlib autograd pymanopt cvxopt -scikit-learn \ No newline at end of file +scikit-learn +cvxpy \ No newline at end of file diff --git a/ot/__init__.py b/ot/__init__.py index a3adfccb0..6f3457f90 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -5,7 +5,7 @@ :py:mod:`ot.utils`, :py:mod:`ot.datasets`, :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` - , :py:mod:`ot.unbalanced`. + , :py:mod:`ot.unbalanced`, :py:mod`ot.mapping`. The following sub-modules are not imported due to additional dependencies: - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. 
- :any:`ot.plot` : depends on :code:`matplotlib` From 3e2e5b8d8ae3e1e23a2af912939d24e3f619cc36 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 15:42:25 +0200 Subject: [PATCH 14/21] PEP8 fixes --- ot/mapping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/mapping.py b/ot/mapping.py index 028843590..b382cc309 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -149,7 +149,7 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly for i in np.where(X_classes == k)[0]: for j in np.where(X_classes == k)[0]: constraints += [ - phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) \ + phi[i] >= phi[j] + G[j].T @ (X[i] - X[j]) + c1 * cvx.sum_squares(G[i] - G[j]) + c2 * cvx.sum_squares(X[i] - X[j]) - c3 * (G[j] - G[i]).T @ (X[j] - X[i]) ] problem = cvx.Problem(objective, constraints) @@ -325,7 +325,7 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla k = Y_classes[y_idx] for j in np.where(X_classes == k)[0]: constraints += [ - phi_l_y >= phi[j] + G[j].T @ (Y[y_idx] - X[j]) + c1 * cvx.sum_squares(G_l_y - G[j]) \ + phi_l_y >= phi[j] + G[j].T @ (Y[y_idx] - X[j]) + c1 * cvx.sum_squares(G_l_y - G[j]) + c2 * cvx.sum_squares(Y[y_idx] - X[j]) - c3 * (G[j] - G_l_y).T @ (X[j] - Y[y_idx]) ] problem = cvx.Problem(objective, constraints) @@ -348,7 +348,7 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla constraints = [] for i in np.where(X_classes == k)[0]: constraints += [ - phi[i] >= phi_u_y + G_u_y.T @ (X[i] - Y[y_idx]) + c1 * cvx.sum_squares(G[i] - G_u_y) \ + phi[i] >= phi_u_y + G_u_y.T @ (X[i] - Y[y_idx]) + c1 * cvx.sum_squares(G[i] - G_u_y) + c2 * cvx.sum_squares(X[i] - Y[y_idx]) - c3 * (G_u_y - G[i]).T @ (Y[y_idx] - X[i]) ] problem = cvx.Problem(objective, constraints) From 596edd4bed8f43eaebd11d2bcb4947ba4b51d15d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 15:43:37 +0200 Subject: [PATCH 15/21] test typo fix --- test/test_da.py | 2 +- test/test_mapping.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index bcc1275b2..a78fc4048 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -25,7 +25,7 @@ nocvxpy = False except ImportError: - npcvxpy = True + nocvxpy = True def test_class_jax_tf(): diff --git a/test/test_mapping.py b/test/test_mapping.py index 7b11d0094..6846a25ba 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -13,7 +13,7 @@ import cvxpy # noqa: F401 nocvxpy = False except ImportError: - npcvxpy = True + nocvxpy = True @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") From 80fa0b9124a843501cc470b2bcc7068d3399ff8c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 15:57:39 +0200 Subject: [PATCH 16/21] ssnb da backend test fix --- test/test_da.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index a78fc4048..843a47d2d 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -830,8 +830,9 @@ def test_nearest_brenier_potential(nx): for ssnb in [ot.da.NearestBrenierPotential(log=True), ot.da.NearestBrenierPotential(log=False)]: ssnb.fit(Xs=X, Xt=X) G_lu = ssnb.transform(Xs=X) - np.testing.assert_almost_equal(G_lu[0], X) # 'new' input isn't new, so should be equal to target - np.testing.assert_almost_equal(G_lu[1], X) + # 'new' input isn't new, so should be equal to target + np.testing.assert_almost_equal(nx.to_numpy(G_lu[0]), nx.to_numpy(X)) + 
np.testing.assert_almost_equal(nx.to_numpy(G_lu[1]), nx.to_numpy(X)) @pytest.mark.skipif(nosklearn, reason="No sklearn available") From 0a349cec74487db02823217a438fcfcc23d741c2 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 21 Sep 2023 17:06:25 +0200 Subject: [PATCH 17/21] moved joint ot mappings to the mapping module --- ot/__init__.py | 5 +- ot/da.py | 423 +----------------------------------------- ot/mapping.py | 425 ++++++++++++++++++++++++++++++++++++++++++- test/test_da.py | 6 - test/test_mapping.py | 6 + 5 files changed, 435 insertions(+), 430 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 6f3457f90..58445c2cf 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -52,7 +52,8 @@ from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve -from .mapping import nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds +from .mapping import (nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds, joint_OT_mapping_kernel, + joint_OT_mapping_linear) # utils functions from .utils import dist, unif, tic, toc, toq @@ -71,4 +72,4 @@ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'nearest_brenier_potential_fit', - 'nearest_brenier_potential_predict_bounds'] + 'nearest_brenier_potential_predict_bounds', 'joint_OT_mapping_kernel', 'joint_OT_mapping_linear'] diff --git a/ot/da.py b/ot/da.py index 18ea1a72e..8764268f0 100644 --- a/ot/da.py +++ b/ot/da.py @@ -23,7 +23,8 @@ from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping from .optim import cg from .optim import gcg -from .mapping import nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds +from .mapping import nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds, joint_OT_mapping_linear, \ + joint_OT_mapping_kernel def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -258,426 +259,6 @@ def df(G): verbose=verbose, log=log) -def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, - verbose2=False, numItermax=100, numInnerItermax=10, - stopInnerThr=1e-6, stopThr=1e-5, log=False, - **kwargs): - r"""Joint OT and linear mapping estimation as proposed in - :ref:`[8] `. - - The function solves the following optimization problem: - - .. math:: - \min_{\gamma,L}\quad \|L(\mathbf{X_s}) - n_s\gamma \mathbf{X_t} \|^2_F + - \mu \langle \gamma, \mathbf{M} \rangle_F + \eta \|L - \mathbf{I}\|^2_F - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - where : - - - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in - :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`) - - :math:`L` is a :math:`d\times d` linear operator that approximates the barycentric - mapping - - :math:`\mathbf{I}` is the identity matrix (neutral linear mapping) - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights - - The problem consist in solving jointly an optimal transport matrix - :math:`\gamma` and a linear mapping that fits the barycentric mapping - :math:`n_s\gamma \mathbf{X_t}`. - - One can also estimate a mapping with constant bias (see supplementary - material of :ref:`[8] `) using the bias optional argument. 
- - The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of :math:`\mathbf{G}` (using conditional gradient) - and the update of :math:`\mathbf{L}` using a classical least square solver. - - - Parameters - ---------- - xs : array-like (ns,d) - samples in the source domain - xt : array-like (nt,d) - samples in the target domain - mu : float,optional - Weight for the linear OT loss (>0) - eta : float, optional - Regularization term for the linear mapping L (>0) - bias : bool,optional - Estimate linear mapping with constant bias - numItermax : int, optional - Max number of BCD iterations - stopThr : float, optional - Stop threshold on relative loss decrease (>0) - numInnerItermax : int, optional - Max number of iterations (inner CG solver) - stopInnerThr : float, optional - Stop threshold on error (inner CG solver) (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns, nt) array-like - Optimal transportation matrix for the given parameters - L : (d, d) array-like - Linear mapping matrix ((:math:`d+1`, `d`) if bias) - log : dict - log dictionary return only if log==True in parameters - - - .. _references-joint-OT-mapping-linear: - References - ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, - "Mapping estimation for discrete optimal transport", - Neural Information Processing Systems (NIPS), 2016. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - xs, xt = list_to_array(xs, xt) - nx = get_backend(xs, xt) - - ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1] - - if bias: - xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1) - xstxs = nx.dot(xs1.T, xs1) - Id = nx.eye(d + 1, type_as=xs) - Id[-1] = 0 - I0 = Id[:, :-1] - - def sel(x): - return x[:-1, :] - else: - xs1 = xs - xstxs = nx.dot(xs1.T, xs1) - Id = nx.eye(d, type_as=xs) - I0 = Id - - def sel(x): - return x - - if log: - log = {'err': []} - - a = unif(ns, type_as=xs) - b = unif(nt, type_as=xt) - M = dist(xs, xt) * ns - G = emd(a, b, M) - - vloss = [] - - def loss(L, G): - """Compute full loss""" - return ( - nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2) - + mu * nx.sum(G * M) - + eta * nx.sum(sel(L - I0) ** 2) - ) - - def solve_L(G): - """ solve L problem with fixed G (least square)""" - xst = ns * nx.dot(G, xt) - return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0) - - def solve_G(L, G0): - """Update G with CG algorithm""" - xsi = nx.dot(xs1, L) - - def f(G): - return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) - - def df(G): - return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) - - G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, - numItermax=numInnerItermax, stopThr=stopInnerThr) - return G - - L = solve_L(G) - - vloss.append(loss(L, G)) - - if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0)) - - # init loop - if numItermax > 0: - loop = 1 - else: - loop = 0 - it = 0 - - while loop: - - it += 1 - - # update G - G = solve_G(L, G) - - # update L - L = solve_L(G) - - vloss.append(loss(L, G)) - - if it >= numItermax: - loop = 0 - - if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr: - loop = 0 - - if verbose: - if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format( - it, vloss[-1], (vloss[-1] - vloss[-2]) / 
abs(vloss[-2]))) - if log: - log['loss'] = vloss - return G, L, log - else: - return G, L - - -def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', - sigma=1, bias=False, verbose=False, verbose2=False, - numItermax=100, numInnerItermax=10, - stopInnerThr=1e-6, stopThr=1e-5, log=False, - **kwargs): - r"""Joint OT and nonlinear mapping estimation with kernels as proposed in - :ref:`[8] `. - - The function solves the following optimization problem: - - .. math:: - \min_{\gamma, L\in\mathcal{H}}\quad \|L(\mathbf{X_s}) - - n_s\gamma \mathbf{X_t}\|^2_F + \mu \langle \gamma, \mathbf{M} \rangle_F + - \eta \|L\|^2_\mathcal{H} - - s.t. \ \gamma \mathbf{1} = \mathbf{a} - - \gamma^T \mathbf{1} = \mathbf{b} - - \gamma \geq 0 - - - where : - - - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in - :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`) - - :math:`L` is a :math:`n_s \times d` linear operator on a kernel matrix that - approximates the barycentric mapping - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights - - The problem consist in solving jointly an optimal transport matrix - :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping - :math:`n_s\gamma \mathbf{X_t}`. - - One can also estimate a mapping with constant bias (see supplementary - material of :ref:`[8] `) using the bias optional argument. - - The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of :math:`\mathbf{G}` (using conditional gradient) - and the update of :math:`\mathbf{L}` using a classical kernel least square solver. - - - Parameters - ---------- - xs : array-like (ns,d) - samples in the source domain - xt : array-like (nt,d) - samples in the target domain - mu : float,optional - Weight for the linear OT loss (>0) - eta : float, optional - Regularization term for the linear mapping L (>0) - kerneltype : str,optional - kernel used by calling function :py:func:`ot.utils.kernel` (gaussian by default) - sigma : float, optional - Gaussian kernel bandwidth. - bias : bool,optional - Estimate linear mapping with constant bias - verbose : bool, optional - Print information along iterations - verbose2 : bool, optional - Print information along iterations - numItermax : int, optional - Max number of BCD iterations - numInnerItermax : int, optional - Max number of iterations (inner CG solver) - stopInnerThr : float, optional - Stop threshold on error (inner CG solver) (>0) - stopThr : float, optional - Stop threshold on relative loss decrease (>0) - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns, nt) array-like - Optimal transportation matrix for the given parameters - L : (ns, d) array-like - Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) - log : dict - log dictionary return only if log==True in parameters - - - .. _references-joint-OT-mapping-kernel: - References - ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, - "Mapping estimation for discrete optimal transport", - Neural Information Processing Systems (NIPS), 2016. 
- - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - xs, xt = list_to_array(xs, xt) - nx = get_backend(xs, xt) - - ns, nt = xs.shape[0], xt.shape[0] - - K = kernel(xs, xs, method=kerneltype, sigma=sigma) - if bias: - K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1) - Id = nx.eye(ns + 1, type_as=xs) - Id[-1] = 0 - Kp = nx.eye(ns + 1, type_as=xs) - Kp[:ns, :ns] = K - - # ls regu - # K0 = K1.T.dot(K1)+eta*I - # Kreg=I - - # RKHS regul - K0 = nx.dot(K1.T, K1) + eta * Kp - Kreg = Kp - - else: - K1 = K - Id = nx.eye(ns, type_as=xs) - - # ls regul - # K0 = K1.T.dot(K1)+eta*I - # Kreg=I - - # proper kernel ridge - K0 = K + eta * Id - Kreg = K - - if log: - log = {'err': []} - - a = unif(ns, type_as=xs) - b = unif(nt, type_as=xt) - M = dist(xs, xt) * ns - G = emd(a, b, M) - - vloss = [] - - def loss(L, G): - """Compute full loss""" - return ( - nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2) - + mu * nx.sum(G * M) - + eta * nx.trace(dots(L.T, Kreg, L)) - ) - - def solve_L_nobias(G): - """ solve L problem with fixed G (least square)""" - xst = ns * nx.dot(G, xt) - return nx.solve(K0, xst) - - def solve_L_bias(G): - """ solve L problem with fixed G (least square)""" - xst = ns * nx.dot(G, xt) - return nx.solve(K0, nx.dot(K1.T, xst)) - - def solve_G(L, G0): - """Update G with CG algorithm""" - xsi = nx.dot(K1, L) - - def f(G): - return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) - - def df(G): - return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) - - G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, - numItermax=numInnerItermax, stopThr=stopInnerThr) - return G - - if bias: - solve_L = solve_L_bias - else: - solve_L = solve_L_nobias - - L = solve_L(G) - - vloss.append(loss(L, G)) - - if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0)) - - # init loop - if numItermax > 0: - loop = 1 - else: - loop = 0 - it = 0 - - while loop: - - it += 1 - - # update G - G = solve_G(L, G) - - # update L - L = solve_L(G) - - vloss.append(loss(L, G)) - - if it >= numItermax: - loop = 0 - - if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr: - loop = 0 - - if verbose: - if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format( - it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2]))) - if log: - log['loss'] = vloss - return G, L, log - else: - return G, L - - OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping) diff --git a/ot/mapping.py b/ot/mapping.py index b382cc309..76bc1f308 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -4,13 +4,16 @@ """ # Author: Eloi Tanguy +# Remi Flamary # # License: MIT License from .backend import get_backend, to_numpy from .lp import emd import numpy as np -from .utils import dist, unif + +from .optim import cg +from .utils import dist, unif, list_to_array, kernel, dots def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6, @@ -369,3 +372,423 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla if not log: return phi_lu, G_lu return phi_lu, G_lu, log_dict + + +def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, + verbose2=False, numItermax=100, numInnerItermax=10, + stopInnerThr=1e-6, stopThr=1e-5, log=False, + **kwargs): + r"""Joint OT and linear mapping estimation as proposed in + :ref:`[8] `. 
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \min_{\gamma,L}\quad \|L(\mathbf{X_s}) - n_s\gamma \mathbf{X_t} \|^2_F +
+          \mu \langle \gamma, \mathbf{M} \rangle_F + \eta \|L - \mathbf{I}\|^2_F
+
+        s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+             \gamma^T \mathbf{1} = \mathbf{b}
+
+             \gamma \geq 0
+
+    where :
+
+    - :math:`\mathbf{M}` is the (`ns`, `nt`) squared Euclidean cost matrix between samples in
+      :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`)
+    - :math:`L` is a :math:`d\times d` linear operator that approximates the barycentric
+      mapping
+    - :math:`\mathbf{I}` is the identity matrix (neutral linear mapping)
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights
+
+    The problem consists in jointly solving for an optimal transport matrix
+    :math:`\gamma` and a linear mapping that fits the barycentric mapping
+    :math:`n_s\gamma \mathbf{X_t}`.
+
+    One can also estimate a mapping with constant bias (see supplementary
+    material of :ref:`[8] <references-joint-OT-mapping-linear>`) using the bias optional argument.
+
+    The algorithm used for solving the problem is block coordinate
+    descent, alternating between updates of :math:`\mathbf{G}` (using conditional gradient)
+    and updates of :math:`\mathbf{L}` using a classical least squares solver.
+
+
+    Parameters
+    ----------
+    xs : array-like (ns,d)
+        samples in the source domain
+    xt : array-like (nt,d)
+        samples in the target domain
+    mu : float, optional
+        Weight for the linear OT loss (>0)
+    eta : float, optional
+        Regularization term for the linear mapping L (>0)
+    bias : bool, optional
+        Estimate linear mapping with constant bias
+    numItermax : int, optional
+        Max number of BCD iterations
+    stopThr : float, optional
+        Stop threshold on relative loss decrease (>0)
+    numInnerItermax : int, optional
+        Max number of iterations (inner CG solver)
+    stopInnerThr : float, optional
+        Stop threshold on error (inner CG solver) (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns, nt) array-like
+        Optimal transportation matrix for the given parameters
+    L : (d, d) array-like
+        Linear mapping matrix ((:math:`d+1`, `d`) if bias)
+    log : dict
+        log dictionary, returned only if log==True in parameters
+
+
+    .. _references-joint-OT-mapping-linear:
+    References
+    ----------
+    .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+        "Mapping estimation for discrete optimal transport",
+        Neural Information Processing Systems (NIPS), 2016.
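+
+    Examples
+    --------
+    A minimal sketch of typical usage on small toy arrays (the data and values
+    are illustrative only; shapes follow the conventions above):
+
+    >>> import numpy as np
+    >>> from ot.mapping import joint_OT_mapping_linear
+    >>> xs = np.array([[0., 0.], [1., 1.]])
+    >>> xt = xs + 1  # target is a translation of the source
+    >>> G, L = joint_OT_mapping_linear(xs, xt)
+    >>> G.shape, L.shape
+    ((2, 2), (2, 2))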
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.optim.cg : General regularized OT
+
+    """
+    xs, xt = list_to_array(xs, xt)
+    nx = get_backend(xs, xt)
+
+    ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]
+
+    if bias:
+        xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1)
+        xstxs = nx.dot(xs1.T, xs1)
+        Id = nx.eye(d + 1, type_as=xs)
+        Id[-1] = 0
+        I0 = Id[:, :-1]
+
+        def sel(x):
+            return x[:-1, :]
+    else:
+        xs1 = xs
+        xstxs = nx.dot(xs1.T, xs1)
+        Id = nx.eye(d, type_as=xs)
+        I0 = Id
+
+        def sel(x):
+            return x
+
+    if log:
+        log = {'err': []}
+
+    a = unif(ns, type_as=xs)
+    b = unif(nt, type_as=xt)
+    M = dist(xs, xt) * ns
+    G = emd(a, b, M)
+
+    vloss = []
+
+    def loss(L, G):
+        """Compute full loss"""
+        return (
+            nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2)
+            + mu * nx.sum(G * M)
+            + eta * nx.sum(sel(L - I0) ** 2)
+        )
+
+    def solve_L(G):
+        """ solve L problem with fixed G (least square)"""
+        xst = ns * nx.dot(G, xt)
+        return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0)
+
+    def solve_G(L, G0):
+        """Update G with CG algorithm"""
+        xsi = nx.dot(xs1, L)
+
+        def f(G):
+            return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
+
+        def df(G):
+            return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
+
+        G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
+               numItermax=numInnerItermax, stopThr=stopInnerThr)
+        return G
+
+    L = solve_L(G)
+
+    vloss.append(loss(L, G))
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}'.format(
+            'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+        print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0))
+
+    # init loop
+    if numItermax > 0:
+        loop = 1
+    else:
+        loop = 0
+    it = 0
+
+    while loop:
+
+        it += 1
+
+        # update G
+        G = solve_G(L, G)
+
+        # update L
+        L = solve_L(G)
+
+        vloss.append(loss(L, G))
+
+        if it >= numItermax:
+            loop = 0
+
+        if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr:
+            loop = 0
+
+        if verbose:
+            if it % 20 == 0:
+                print('{:5s}|{:12s}|{:8s}'.format(
+                    'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+            print('{:5d}|{:8e}|{:8e}'.format(
+                it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2])))
+    if log:
+        log['loss'] = vloss
+        return G, L, log
+    else:
+        return G, L
+
+
+def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
+                            sigma=1, bias=False, verbose=False, verbose2=False,
+                            numItermax=100, numInnerItermax=10,
+                            stopInnerThr=1e-6, stopThr=1e-5, log=False,
+                            **kwargs):
+    r"""Joint OT and nonlinear mapping estimation with kernels as proposed in
+    :ref:`[8] <references-joint-OT-mapping-kernel>`.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \min_{\gamma, L\in\mathcal{H}}\quad \|L(\mathbf{X_s}) -
+        n_s\gamma \mathbf{X_t}\|^2_F + \mu \langle \gamma, \mathbf{M} \rangle_F +
+        \eta \|L\|^2_\mathcal{H}
+
+        s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+             \gamma^T \mathbf{1} = \mathbf{b}
+
+             \gamma \geq 0
+
+
+    where :
+
+    - :math:`\mathbf{M}` is the (`ns`, `nt`) squared Euclidean cost matrix between samples in
+      :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`)
+    - :math:`L` is a :math:`n_s \times d` linear operator on a kernel matrix that
+      approximates the barycentric mapping
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights
+
+    The problem consists in jointly solving for an optimal transport matrix
+    :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
+    :math:`n_s\gamma \mathbf{X_t}`.
+
+    One can also estimate a mapping with constant bias (see supplementary
+    material of :ref:`[8] <references-joint-OT-mapping-kernel>`) using the bias optional argument.
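+
+    Note that the estimated mapping is parameterized by its values on the
+    source samples: a new sample :math:`x` is mapped through the kernel
+    expansion :math:`x \mapsto \mathbf{K}(x, \mathbf{X_s}) L` (with a constant
+    column appended to the kernel matrix when `bias` is True), which is how
+    :code:`ot.da.MappingTransport` applies the fitted mapping out of sample.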
+
+    The algorithm used for solving the problem is block coordinate
+    descent, alternating between updates of :math:`\mathbf{G}` (using conditional gradient)
+    and updates of :math:`\mathbf{L}` using a classical kernel least squares solver.
+
+
+    Parameters
+    ----------
+    xs : array-like (ns,d)
+        samples in the source domain
+    xt : array-like (nt,d)
+        samples in the target domain
+    mu : float, optional
+        Weight for the linear OT loss (>0)
+    eta : float, optional
+        Regularization term for the linear mapping L (>0)
+    kerneltype : str, optional
+        kernel used by calling function :py:func:`ot.utils.kernel` (gaussian by default)
+    sigma : float, optional
+        Gaussian kernel bandwidth.
+    bias : bool, optional
+        Estimate linear mapping with constant bias
+    verbose : bool, optional
+        Print information along iterations
+    verbose2 : bool, optional
+        Print information along iterations
+    numItermax : int, optional
+        Max number of BCD iterations
+    numInnerItermax : int, optional
+        Max number of iterations (inner CG solver)
+    stopInnerThr : float, optional
+        Stop threshold on error (inner CG solver) (>0)
+    stopThr : float, optional
+        Stop threshold on relative loss decrease (>0)
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns, nt) array-like
+        Optimal transportation matrix for the given parameters
+    L : (ns, d) array-like
+        Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias)
+    log : dict
+        log dictionary, returned only if log==True in parameters
+
+
+    .. _references-joint-OT-mapping-kernel:
+    References
+    ----------
+    .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+        "Mapping estimation for discrete optimal transport",
+        Neural Information Processing Systems (NIPS), 2016.
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.optim.cg : General regularized OT
+
+    """
+    xs, xt = list_to_array(xs, xt)
+    nx = get_backend(xs, xt)
+
+    ns, nt = xs.shape[0], xt.shape[0]
+
+    K = kernel(xs, xs, method=kerneltype, sigma=sigma)
+    if bias:
+        K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1)
+        Id = nx.eye(ns + 1, type_as=xs)
+        Id[-1] = 0
+        Kp = nx.eye(ns + 1, type_as=xs)
+        Kp[:ns, :ns] = K
+
+        # ls regul
+        # K0 = K1.T.dot(K1)+eta*I
+        # Kreg=I
+
+        # RKHS regul
+        K0 = nx.dot(K1.T, K1) + eta * Kp
+        Kreg = Kp
+
+    else:
+        K1 = K
+        Id = nx.eye(ns, type_as=xs)
+
+        # ls regul
+        # K0 = K1.T.dot(K1)+eta*I
+        # Kreg=I
+
+        # proper kernel ridge
+        K0 = K + eta * Id
+        Kreg = K
+
+    if log:
+        log = {'err': []}
+
+    a = unif(ns, type_as=xs)
+    b = unif(nt, type_as=xt)
+    M = dist(xs, xt) * ns
+    G = emd(a, b, M)
+
+    vloss = []
+
+    def loss(L, G):
+        """Compute full loss"""
+        return (
+            nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2)
+            + mu * nx.sum(G * M)
+            + eta * nx.trace(dots(L.T, Kreg, L))
+        )
+
+    def solve_L_nobias(G):
+        """ solve L problem with fixed G (least square)"""
+        xst = ns * nx.dot(G, xt)
+        return nx.solve(K0, xst)
+
+    def solve_L_bias(G):
+        """ solve L problem with fixed G (least square)"""
+        xst = ns * nx.dot(G, xt)
+        return nx.solve(K0, nx.dot(K1.T, xst))
+
+    def solve_G(L, G0):
+        """Update G with CG algorithm"""
+        xsi = nx.dot(K1, L)
+
+        def f(G):
+            return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
+
+        def df(G):
+            return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
+
+        G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
+               numItermax=numInnerItermax, stopThr=stopInnerThr)
+        return G
+
+    if bias:
+        solve_L = solve_L_bias
+    else:
+        solve_L = solve_L_nobias
+
+    L = solve_L(G)
+
+    vloss.append(loss(L, G))
+
+    if verbose:
+        print('{:5s}|{:12s}|{:8s}'.format(
+            'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+        print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0))
+
+    # init loop
+    if numItermax > 0:
+        loop = 1
+    else:
+        loop = 0
+    it = 0
+
+    while loop:
+
+        it += 1
+
+        # update G
+        G = solve_G(L, G)
+
+        # update L
+        L = solve_L(G)
+
+        vloss.append(loss(L, G))
+
+        if it >= numItermax:
+            loop = 0
+
+        if abs(vloss[-1] - vloss[-2]) / abs(vloss[-2]) < stopThr:
+            loop = 0
+
+        if verbose:
+            if it % 20 == 0:
+                print('{:5s}|{:12s}|{:8s}'.format(
+                    'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+            print('{:5d}|{:8e}|{:8e}'.format(
+                it, vloss[-1], (vloss[-1] - vloss[-2]) / abs(vloss[-2])))
+    if log:
+        log['loss'] = vloss
+        return G, L, log
+    else:
+        return G, L
diff --git a/test/test_da.py b/test/test_da.py
index 843a47d2d..8f248c484 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -862,12 +862,6 @@ def test_emd_laplace(nx):
         nx.to_numpy(nx.sum(coupling, axis=1)), unif(ns), rtol=1e-3, atol=1e-3)
 
 
-def test_joint_OT_mapping_verbose():
-    xs = np.zeros((2, 1))
-    ot.da.joint_OT_mapping_kernel(xs, xs, verbose=True)
-    ot.da.joint_OT_mapping_linear(xs, xs, verbose=True)
-
-
 @pytest.skip_backend("jax")
 @pytest.skip_backend("tf")
 def test_sinkhorn_l1l2_gl_cost_vectorized(nx):
diff --git a/test/test_mapping.py b/test/test_mapping.py
index 6846a25ba..f9be0894d 100644
--- a/test/test_mapping.py
+++ b/test/test_mapping.py
@@ -45,3 +45,9 @@ def test_brenier_potential_predict_bounds(nx):
     np.testing.assert_almost_equal(to_numpy(G_lu[1]), to_numpy(X))
     # test with no log but classes
     ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2))
+
+
+def test_joint_OT_mapping_verbose():
+    xs = np.zeros((2, 1))
+    ot.mapping.joint_OT_mapping_kernel(xs, xs, verbose=True)
+    ot.mapping.joint_OT_mapping_linear(xs, xs, verbose=True)

From 8eb1542a0b1b45c9249009c6d16ebc4cc6635017 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 27 Sep 2023 16:31:08 +0200
Subject: [PATCH 18/21] better ssnb example + ssnb initialisation + small joint_ot_mapping tests

---
 examples/others/plot_SSNB.py | 73 ++++++++++++++++++++++++++++--------
 ot/mapping.py                | 24 +++++-------
 test/test_mapping.py         | 16 +++++---
 3 files changed, 77 insertions(+), 36 deletions(-)

diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py
index 249343b93..8e9e25e4f 100644
--- a/examples/others/plot_SSNB.py
+++ b/examples/others/plot_SSNB.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# sphinx_gallery_thumbnail_number = 2
+# sphinx_gallery_thumbnail_number = 4
 r"""
 =====================================================
 Smooth and Strongly Convex Nearest Brenier Potentials
@@ -10,10 +10,10 @@
 :math:`\nabla \varphi \# \mu \approx \nu`. This regularity can be enforced only on the components of a partition
 of the ambient space, which is a relaxation compared to imposing global regularity.
 
-In this example, we consider a source measure :math:`\mu_s` which is the uniform measure on the unit sphere in
+In this example, we consider a source measure :math:`\mu_s` which is the uniform measure on the unit square in
 :math:`\mathbb{R}^2`, and the target measure :math:`\mu_t` which is the image of :math:`\mu_s` by
-:math:`T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), x_2)`. The map :math:`T` is non-smooth, and we wish to approximate it
-using a "Brenier-style" map :math:`\nabla \varphi` which is regular on the partition
+:math:`T(x_1, x_2) = (x_1 + 2\mathrm{sign}(x_2), 2 * x_2)`.
The map :math:`T` is non-smooth, and we wish to approximate +it using a "Brenier-style" map :math:`\nabla \varphi` which is regular on the partition :math:`\lbrace x_1 <=0, x_1>0\rbrace`, which is well adapted to this particular dataset. We represent the gradients of the "bounding potentials" :math:`\varphi_l, \varphi_u` (from [59], Theorem 3.14), @@ -45,36 +45,75 @@ import matplotlib.pyplot as plt import numpy as np import ot +import os # %% # Generating the fitting data -n_fitting_samples = 16 -t = np.linspace(0, 2 * np.pi, n_fitting_samples) -r = 1 -Xs = np.stack([r * np.cos(t), r * np.sin(t)], axis=-1) +n_fitting_samples = 30 +rng = np.random.RandomState(seed=0) +Xs = rng.uniform(-1, 1, size=(n_fitting_samples, 2)) Xs_classes = (Xs[:, 0] < 0).astype(int) -Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), Xs[:, 1]], axis=-1) +Xt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1) plt.scatter(Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c='blue', label='source class 0') plt.scatter(Xs[Xs_classes == 1, 0], Xs[Xs_classes == 1, 1], c='dodgerblue', label='source class 1') plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +plt.axis('equal') plt.title('Splitting sphere dataset') plt.legend(loc='upper right') plt.show() +# %% +# Plotting image of barycentric projection (SSNB initialisation values) +plt.clf() +pi = ot.emd(ot.unif(n_fitting_samples), ot.unif(n_fitting_samples), ot.dist(Xs, Xt)) +plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') +plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +bar_img = pi @ Xt +for i in range(n_fitting_samples): + plt.plot([Xs[i, 0], bar_img[i, 0]], [Xs[i, 1], bar_img[i, 1]], color='black', alpha=.5) +plt.title('Images of in-data source samples by the barycentric map') +plt.legend(loc='upper right') +plt.axis('equal') +plt.show() + # %% # Fitting the Nearest Brenier Potential -phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, seed=0) +L = 3 # need L > 2 to allow the 2*y term, default is 1.4 +if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi.npy'): + phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric', + gradient_lipschitz_constant=L) + np.save('/home/eloi/POT_ssnb/examples/others/phi.npy', phi) + np.save('/home/eloi/POT_ssnb/examples/others/G.npy', G) +else: + phi = np.load('/home/eloi/POT_ssnb/examples/others/phi.npy') + G = np.load('/home/eloi/POT_ssnb/examples/others/G.npy') + +# %% +# Plotting the images of the source data +plt.clf() +plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source') +plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target') +for i in range(n_fitting_samples): + plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color='black', alpha=.5) +plt.title('Images of in-data source samples by the fitted SSNB') +plt.legend(loc='upper right') +plt.axis('equal') +plt.show() # %% # Computing the predictions (images by nabla phi) for random samples of the source distribution -rng = np.random.RandomState(seed=0) -n_predict_samples = 100 -t = rng.uniform(0, 2 * np.pi, size=n_predict_samples) -r = rng.uniform(size=n_predict_samples) -Ys = np.stack([r * np.cos(t), r * np.sin(t)], axis=-1) +n_predict_samples = 50 +Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2)) Ys_classes = (Ys[:, 0] < 0).astype(int) -phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes) +if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi_lu.npy'): + phi_lu, G_lu = 
ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes, + gradient_lipschitz_constant=L) + np.save('/home/eloi/POT_ssnb/examples/others/phi_lu.npy', phi_lu) + np.save('/home/eloi/POT_ssnb/examples/others/G_lu.npy', G_lu) +else: + phi_lu = np.load('/home/eloi/POT_ssnb/examples/others/phi_lu.npy') + G_lu = np.load('/home/eloi/POT_ssnb/examples/others/G_lu.npy') # %% # Plot predictions for the gradient of the lower-bounding potential @@ -85,6 +124,7 @@ plt.plot([Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color='black', alpha=.5) plt.title('Images of new source samples by $\\nabla \\varphi_l$') plt.legend(loc='upper right') +plt.axis('equal') plt.show() # %% @@ -96,4 +136,5 @@ plt.plot([Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color='black', alpha=.5) plt.title('Images of new source samples by $\\nabla \\varphi_u$') plt.legend(loc='upper right') +plt.axis('equal') plt.show() diff --git a/ot/mapping.py b/ot/mapping.py index 76bc1f308..fd236393e 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -17,7 +17,7 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly_convex_constant=.6, - gradient_lipschitz_constant=1.4, its=100, log=False, seed=None): + gradient_lipschitz_constant=1.4, its=100, log=False, init_method='barycentric'): r""" Computes optimal values and gradients at X for a strongly convex potential :math:`\varphi` with Lipschitz gradients on the partitions defined by `X_classes`, where :math:`\varphi` is optimal such that @@ -72,8 +72,8 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly number of iterations, defaults to 100 log : bool, optional record log if true - seed: int or RandomState or None, optional - Seed used for random number generator + init_method : str, optional + 'target' initialises G=V, 'barycentric' initialises at the image of X by the barycentric projection Returns ------- @@ -112,19 +112,15 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly assert X_classes.size == n, "incorrect number of class items" else: X_classes = np.zeros(n) - if a is None: - a = unif(n) - if b is None: - b = unif(n) + a = unif(n) if a is None else nx.to_numpy(a) + b = unif(n) if b is None else nx.to_numpy(b) assert a.shape[-1] == b.shape[-1] == n, 'incorrect measure weight sizes' - if isinstance(seed, np.random.RandomState): - G_val = np.random.randn(n, d) - else: - if seed is not None: - np.random.seed(seed) - G_val = np.random.randn(n, d) - + assert init_method in ['target', 'barycentric'], f"Unsupported initialization method '{init_method}'" + if init_method == 'target': + G_val = V + else: # Init G_val with barycentric projection + G_val = emd(a, b, dist(X, V)) @ V phi_val = None log_dict = { 'G_list': [], diff --git a/test/test_mapping.py b/test/test_mapping.py index f9be0894d..6a4839a39 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -29,10 +29,9 @@ def test_nearest_brenier_potential_fit(nx): X = nx.ones((2, 2)) phi, G, log = ot.nearest_brenier_potential_fit(X, X, its=3, log=True) np.testing.assert_almost_equal(to_numpy(G), to_numpy(X)) # image of source should be close to target - # test without log but with X_classes and seed - ot.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), its=1, seed=0) - # test with seed being a np.random.RandomState - ot.nearest_brenier_potential_fit(X, X, its=1, seed=np.random.RandomState(seed=0)) + # test without log but with X_classes, a, b and other init method + a = nx.ones(2) / 2 + 
ot.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), a=a, b=a, its=1, init_method='target')
 
 
 @pytest.mark.skipif(nocvxpy, reason="No CVXPY available")
@@ -45,7 +46,12 @@ def test_brenier_potential_predict_bounds(nx):
     ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2))
 
 
-def test_joint_OT_mapping_verbose():
-    xs = np.zeros((2, 1))
+def test_joint_OT_mapping():
+    """
+    Complements the tests in test_da, for verbose, log and bias options
+    """
+    xs = np.array([[.1, .2], [-.1, .3]])
     ot.mapping.joint_OT_mapping_kernel(xs, xs, verbose=True)
     ot.mapping.joint_OT_mapping_linear(xs, xs, verbose=True)
+    ot.mapping.joint_OT_mapping_kernel(xs, xs, log=True, bias=True)
+    ot.mapping.joint_OT_mapping_linear(xs, xs, log=True, bias=True)

From f29dcff9a7b28ee1fecb5796f99e54d3942e0208 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 27 Sep 2023 16:31:47 +0200
Subject: [PATCH 19/21] better ssnb example + ssnb initialisation + small joint_ot_mapping tests

---
 examples/others/plot_SSNB.py | 20 ++++----------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py
index 8e9e25e4f..a37f0a942 100644
--- a/examples/others/plot_SSNB.py
+++ b/examples/others/plot_SSNB.py
@@ -80,14 +80,8 @@
 # %%
 # Fitting the Nearest Brenier Potential
 L = 3  # need L > 2 to allow the 2*y term, default is 1.4
-if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi.npy'):
-    phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric',
-                                              gradient_lipschitz_constant=L)
-    np.save('/home/eloi/POT_ssnb/examples/others/phi.npy', phi)
-    np.save('/home/eloi/POT_ssnb/examples/others/G.npy', G)
-else:
-    phi = np.load('/home/eloi/POT_ssnb/examples/others/phi.npy')
-    G = np.load('/home/eloi/POT_ssnb/examples/others/G.npy')
+phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric',
+                                          gradient_lipschitz_constant=L)
 
 # %%
@@ -99,8 +100,8 @@
 n_predict_samples = 50
 Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2))
 Ys_classes = (Ys[:, 0] < 0).astype(int)
-if not os.path.isfile('/home/eloi/POT_ssnb/examples/others/phi_lu.npy'):
-    phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes,
-                                                               gradient_lipschitz_constant=L)
-    np.save('/home/eloi/POT_ssnb/examples/others/phi_lu.npy', phi_lu)
-    np.save('/home/eloi/POT_ssnb/examples/others/G_lu.npy', G_lu)
-else:
-    phi_lu = np.load('/home/eloi/POT_ssnb/examples/others/phi_lu.npy')
-    G_lu = np.load('/home/eloi/POT_ssnb/examples/others/G_lu.npy')
+phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes,
+                                                           gradient_lipschitz_constant=L)

From 7a5e6d7de3888df8a066406e50c70000a885a908 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 27 Sep 2023 16:33:18 +0200
Subject: [PATCH 20/21] removed unused dependency in example

---
 examples/others/plot_SSNB.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py
index a37f0a942..9aa43a2f1 100644
--- a/examples/others/plot_SSNB.py
+++ b/examples/others/plot_SSNB.py
@@ -45,7 +45,6 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import ot
-import os

From 55a9b2836ff465c9490a028bc3b7b6d254b76ad9 Mon Sep 17 00:00:00 2001
From: eloitanguy
Date: Wed, 11 Oct 2023 10:54:18 +0200
Subject: [PATCH 21/21] no longer import mapping in __init__ + example thumbnail fix + made qcqp_constants function private --- examples/others/plot_SSNB.py | 11 ++++++----- ot/__init__.py | 6 +----- ot/mapping.py | 10 +++++++--- test/test_mapping.py | 12 ++++++------ 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index 9aa43a2f1..425105613 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# sphinx_gallery_thumbnail_number = 4 r""" ===================================================== Smooth and Strongly Convex Nearest Brenier Potentials @@ -42,6 +41,8 @@ # Author: Eloi Tanguy # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import matplotlib.pyplot as plt import numpy as np import ot @@ -79,8 +80,8 @@ # %% # Fitting the Nearest Brenier Potential L = 3 # need L > 2 to allow the 2*y term, default is 1.4 -phi, G = ot.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric', - gradient_lipschitz_constant=L) +phi, G = ot.mapping.nearest_brenier_potential_fit(Xs, Xt, Xs_classes, its=10, init_method='barycentric', + gradient_lipschitz_constant=L) # %% # Plotting the images of the source data @@ -99,8 +100,8 @@ n_predict_samples = 50 Ys = rng.uniform(-1, 1, size=(n_predict_samples, 2)) Ys_classes = (Ys[:, 0] < 0).astype(int) -phi_lu, G_lu = ot.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes, - gradient_lipschitz_constant=L) +phi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds(Xs, phi, G, Ys, Xs_classes, Ys_classes, + gradient_lipschitz_constant=L) # %% # Plot predictions for the gradient of the lower-bounding potential diff --git a/ot/__init__.py b/ot/__init__.py index 58445c2cf..44e87eabe 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,7 +35,6 @@ from . import factored from . import solvers from . import gaussian -from . import mapping # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -52,8 +51,6 @@ from .weak import weak_optimal_transport from .factored import factored_optimal_transport from .solvers import solve -from .mapping import (nearest_brenier_potential_fit, nearest_brenier_potential_predict_bounds, joint_OT_mapping_kernel, - joint_OT_mapping_linear) # utils functions from .utils import dist, unif, tic, toc, toq @@ -71,5 +68,4 @@ 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'nearest_brenier_potential_fit', - 'nearest_brenier_potential_predict_bounds', 'joint_OT_mapping_kernel', 'joint_OT_mapping_linear'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/mapping.py b/ot/mapping.py index fd236393e..cf8b14f2d 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- """ Optimal Transport maps and variants + +.. warning:: + Note that by default the module is not imported in :mod:`ot`. 
In order to + use it you need to explicitly import :mod:`ot.mapping` """ # Author: Eloi Tanguy @@ -142,7 +146,7 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly for j in range(n): cost += cvx.sum_squares(G[i, :] - V[j, :]) * plan[i, j] objective = cvx.Minimize(cost) # OT cost - c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + c1, c2, c3 = _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) for k in np.unique(X_classes): # constraints for the convex interpolation for i in np.where(X_classes == k)[0]: @@ -174,7 +178,7 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly return phi_val, G_val, log_dict -def ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): +def _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant): r""" Handy function computing the constants for the Nearest Brenier Potential QCQP problems @@ -309,7 +313,7 @@ def nearest_brenier_potential_predict_bounds(X, phi, G, Y, X_classes=None, Y_cla else: X_classes = np.zeros(n) assert X_classes.size == n, 'wrong number of class items for X' - c1, c2, c3 = ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) + c1, c2, c3 = _ssnb_qcqp_constants(strongly_convex_constant, gradient_lipschitz_constant) phi_lu = np.zeros((2, m)) G_lu = np.zeros((2, m, d)) log_dict = {} diff --git a/test/test_mapping.py b/test/test_mapping.py index 6a4839a39..991b2374c 100644 --- a/test/test_mapping.py +++ b/test/test_mapping.py @@ -18,7 +18,7 @@ @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_ssnb_qcqp_constants(): - c1, c2, c3 = ot.mapping.ssnb_qcqp_constants(.5, 1) + c1, c2, c3 = ot.mapping._ssnb_qcqp_constants(.5, 1) np.testing.assert_almost_equal(c1, 1) np.testing.assert_almost_equal(c2, .5) np.testing.assert_almost_equal(c3, 1) @@ -27,23 +27,23 @@ def test_ssnb_qcqp_constants(): @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_nearest_brenier_potential_fit(nx): X = nx.ones((2, 2)) - phi, G, log = ot.nearest_brenier_potential_fit(X, X, its=3, log=True) + phi, G, log = ot.mapping.nearest_brenier_potential_fit(X, X, its=3, log=True) np.testing.assert_almost_equal(to_numpy(G), to_numpy(X)) # image of source should be close to target # test without log but with X_classes, a, b and other init method a = nx.ones(2) / 2 - ot.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), a=a, b=a, its=1, init_method='target') + ot.mapping.nearest_brenier_potential_fit(X, X, X_classes=nx.ones(2), a=a, b=a, its=1, init_method='target') @pytest.mark.skipif(nocvxpy, reason="No CVXPY available") def test_brenier_potential_predict_bounds(nx): X = nx.ones((2, 2)) - phi, G = ot.nearest_brenier_potential_fit(X, X, its=3) - phi_lu, G_lu, log = ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, log=True) + phi, G = ot.mapping.nearest_brenier_potential_fit(X, X, its=3) + phi_lu, G_lu, log = ot.mapping.nearest_brenier_potential_predict_bounds(X, phi, G, X, log=True) # 'new' input isn't new, so should be equal to target np.testing.assert_almost_equal(to_numpy(G_lu[0]), to_numpy(X)) np.testing.assert_almost_equal(to_numpy(G_lu[1]), to_numpy(X)) # test with no log but classes - ot.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2)) + ot.mapping.nearest_brenier_potential_predict_bounds(X, phi, G, X, X_classes=nx.ones(2), Y_classes=nx.ones(2)) def test_joint_OT_mapping():