Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] New API ot.solve_sample #558

Closed
wants to merge 13 commits into from
5 changes: 3 additions & 2 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import solvers
from . import gaussian


# OT functions
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
binary_search_circle, wasserstein_circle,
Expand All @@ -50,7 +51,7 @@
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov
from .solvers import solve, solve_gromov, solve_sample

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -65,7 +66,7 @@
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve', 'solve_gromov',
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
146 changes: 144 additions & 2 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
#
# License: MIT License

from .utils import OTResult
from .utils import OTResult, unif, dist
from .lp import emd2
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log
from .bregman import sinkhorn_log, empirical_sinkhorn
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
Expand All @@ -21,6 +21,9 @@
entropic_semirelaxed_gromov_wasserstein2)
from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2




#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2


Expand Down Expand Up @@ -848,3 +851,142 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx)

return res




##### new ot.solve_sample function

def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None,
potentials_init=None, tol=None, verbose=False):

r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

The function solves the following general optimal transport problem

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
\lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
\lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
default ``reg=None`` and there is no regularization. The unbalanced marginal
penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
`unbalanced_type`. By default ``unbalanced=None`` and the function
solves the exact optimal transport problem (respecting the marginals).

Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
a : array-like, shape (dim_a,), optional
Samples weights in the source domain (default is uniform)
b : array-like, shape (dim_b,), optional
Samples weights in the source domain (default is uniform)
reg : float, optional
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT)
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
is_Lazy : bool, optional
Return :any:`OTResultlazy` object to reduce memory cost when True, by default False
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
max_iter : int, optional
Maximum number of iteration, by default None (default values in each solvers)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
Initialization of the OT dual potentials for iterative methods, by default None
tol : _type_, optional
Tolerance for solution precision, by default None (default values in each solvers)
verbose : bool, optional
Print information in the solver, by default False

Returns
-------

res : OTResult()
Result of the optimization problem. The information can be obtained as follows:

- res.plan : OT plan :math:`\mathbf{T}`
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan

See :any:`OTResult` for more information.

"""

# Detect backend
arr = [X_s,X_t]
if a is not None:
arr.append(a)
if b is not None:
arr.append(b)
nx = get_backend(*arr)

# create uniform weights if not given
ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)

if metric != 'sqeuclidean':
raise (NotImplementedError('Not implemented metric = {} (only sqeulidean)'.format(metric)))


# default values for solutions
potentials = None
lazy_plan = None

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if batch_size is None:
batch_size = 100

if is_Lazy:
################# WIP ####################
if reg is None or reg == 0: # EMD solver for isLazy ?

if unbalanced is None: # balanced EMD solver for isLazy ?
raise (NotImplementedError('Not implemented balanced with no regularization'))

else:
raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type)))


#############################################

else:
if unbalanced is None:
u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol,
isLazy=True, batchSize=batch_size, verbose=verbose, log=True)
# compute potentials
potentials = (log["u"], log["v"])

# compute lazy_plan
# ...

raise (NotImplementedError('Not implemented balanced with regularization'))

else:
raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))

else:
# compute cost matrix M and use solve function
M = dist(X_s, X_t, metric)

res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose)
return res
3 changes: 1 addition & 2 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,6 @@ def citation(self):
}
"""


class LazyTensor(object):
""" A lazy tensor is a tensor that is not stored in memory. Instead, it is
defined by a function that computes its values on the fly from slices.
Expand Down Expand Up @@ -1233,4 +1232,4 @@ def __getitem__(self, key):
return self._getitem(*k, **self.kwargs)

def __repr__(self):
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
90 changes: 90 additions & 0 deletions test/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,93 @@ def test_solve_gromov_not_implemented(nx):
ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5)
with pytest.raises(NotImplementedError):
ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False)




######## Test functions for ot.solve_sample ########


@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
def test_solve_sample(nx):
# test solve_sample when is_Lazy = False
n = 100
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))

a = ot.utils.unif(X_s.shape[0])
b = ot.utils.unif(X_t.shape[0])

# solve unif weights
sol0 = ot.solve_sample(X_s, X_t)

# solve signe weights
sol = ot.solve_sample(X_s, X_t, a, b)

# check some attributes
sol.potentials
sol.sparse_plan
sol.marginals
sol.status

assert_allclose_sol(sol0, sol)

# solve in backend
X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
solb = ot.solve_sample(X_sb, X_tb, ab, bb)

assert_allclose_sol(sol, solb)

# test not implemented unbalanced and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence')

# test not implemented reg_type and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence')



def test_lazy_solve_sample(nx):
# test solve_sample when is_Lazy = True
n = 100
X_s = np.reshape(1.0 * np.arange(n), (n, 1))
X_t = np.reshape(1.0 * np.arange(0, n), (n, 1))

a = ot.utils.unif(X_s.shape[0])
b = ot.utils.unif(X_t.shape[0])

# solve unif weights
sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True

# solve signe weights
sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True)

# check some attributes
sol.potentials
sol.lazy_plan

assert_allclose_sol(sol0, sol)

# solve in backend
X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b)
solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True)

assert_allclose_sol(sol, solb)

# test not implemented reg==0 (or None) + balanced and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default

# test not implemented reg==0 (or None) + unbalanced_type and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default

# test not implemented reg != 0 + unbalanced_type and check raise
with pytest.raises(NotImplementedError):
sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True)





1 change: 1 addition & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,4 @@ def test_lowrank_LazyTensor(nx):
T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)

np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))

Loading