add kl_loss to all semi-relaxed (f)gw solvers (#559)
cedricvincentcuaz authored Nov 4, 2023
1 parent a73ad08 commit 1071759
Showing 4 changed files with 230 additions and 210 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
@@ -14,6 +14,7 @@
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
+ Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556)
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
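For context, a minimal usage sketch of the option this commit enables. It is not part of the diff; the random data and solver parameters are illustrative only.

```python
# Illustrative sketch (not from this commit): the semi-relaxed solvers
# now accept loss_fun='kl_loss'.
import numpy as np
import ot

rng = np.random.RandomState(0)
ns, nt = 20, 15
C1, C2 = ot.dist(rng.randn(ns, 2)), ot.dist(rng.randn(nt, 2))
C1, C2 = C1 / C1.max(), C2 / C2.max()  # intra-domain structure matrices
p = ot.unif(ns)  # semi-relaxed: only the source marginal is constrained

# Conditional-gradient solver with the new KL loss
T, log = ot.gromov.semirelaxed_gromov_wasserstein(
    C1, C2, p, loss_fun='kl_loss', log=True)

# Divergence value, now differentiable w.r.t. C1 and C2 under 'kl_loss'
srgw = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='kl_loss')
```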
80 changes: 41 additions & 39 deletions ot/gromov/_semirelaxed.py
@@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
If left to its default value None, the uniform distribution is taken.
loss_fun : str
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Whether C1 and C2 are assumed to be symmetric.
If left to its default value None, a symmetry test will be conducted.
@@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- if loss_fun == 'kl_loss':
- raise NotImplementedError()
arr = [C1, C2]
if p is not None:
arr.append(list_to_array(p))
@@ -139,7 +136,7 @@ def df(G):
return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))

def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
- return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs)
+ return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs)

if log:
res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
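For reference, the `f`, `df` and `line_search` callbacks above feed `semirelaxed_cg`, a conditional-gradient scheme for the semi-relaxed problem. A sketch in the docstrings' notation, not text from the diff:

```latex
% Semi-relaxed GW: the second marginal of T is left free.
\mathrm{srGW}(\mathbf{C_1}, \mathbf{C_2}, \mathbf{p}) =
  \min_{\mathbf{T} \geq 0,\; \mathbf{T}\mathbf{1} = \mathbf{p}}
  \sum_{i,j,k,l} L\big(\mathbf{C_1}[i,k], \mathbf{C_2}[j,l]\big)\, \mathbf{T}[i,j]\, \mathbf{T}[k,l]
```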
@@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
If left to its default value None, the uniform distribution is taken.
loss_fun : str
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Whether C1 and C2 are assumed to be symmetric.
If left to its default value None, a symmetry test will be conducted.
@@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
- srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))

+ elif loss_fun == 'kl_loss':
+ gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+ gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)

+ srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))

if log:
return srgw, log_srgw
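As a sanity check on the `gC1`/`gC2` expressions above (my derivation, not part of the commit): with :math:`L(a, b) = a \log(a/b) - a + b`, `T` the fixed optimal plan, :math:`p = T\mathbf{1}` and :math:`q = T^\top\mathbf{1}`,

```latex
\frac{\partial L}{\partial a} = \log a - \log b, \qquad
\frac{\partial L}{\partial b} = 1 - \frac{a}{b}
\;\Longrightarrow\;
\nabla_{C_1} = \log(C_1) \odot p p^\top - T \log(C_2)\, T^\top, \qquad
\nabla_{C_2} = q q^\top - \frac{T^\top C_1 T}{C_2} \quad \text{(elementwise division)}
```

hence the leading minus sign on the matrix term of `gC2`.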
@@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein(
If left to its default value None, the uniform distribution is taken.
loss_fun : str
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Whether C1 and C2 are assumed to be symmetric.
If left to its default value None, a symmetry test will be conducted.
@@ -332,9 +332,6 @@
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- if loss_fun == 'kl_loss':
- raise NotImplementedError()

arr = [M, C1, C2]
if p is not None:
arr.append(list_to_array(p))
@@ -382,7 +379,7 @@ def df(G):

def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_semirelaxed_gromov_linesearch(
- G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs)
+ G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs)

if log:
res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
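For reference, the weights `(1 - alpha) * M` and `alpha` passed to `semirelaxed_cg` above correspond to the fused objective (a sketch in the same notation as before, not text from the diff):

```latex
\mathrm{srFGW}_{\alpha} =
  \min_{\mathbf{T} \geq 0,\; \mathbf{T}\mathbf{1} = \mathbf{p}}
  (1 - \alpha)\, \langle \mathbf{T}, \mathbf{M} \rangle_F
  + \alpha \sum_{i,j,k,l} L\big(\mathbf{C_1}[i,k], \mathbf{C_2}[j,l]\big)\, \mathbf{T}[i,j]\, \mathbf{T}[k,l]
```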
@@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
If left to its default value None, the uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
symmetric : bool, optional
Whether C1 and C2 are assumed to be symmetric.
If left to its default value None, a symmetry test will be conducted.
@@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo
if loss_fun == 'square_loss':
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
- if isinstance(alpha, int) or isinstance(alpha, float):
- srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
- (alpha * gC1, alpha * gC2, (1 - alpha) * T))
- else:
- lin_term = nx.sum(T * M)
- srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
- srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
- (alpha * gC1, alpha * gC2, (1 - alpha) * T,
- srgw_term - lin_term))

+ elif loss_fun == 'kl_loss':
+ gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T))
+ gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)

+ if isinstance(alpha, int) or isinstance(alpha, float):
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ else:
+ lin_term = nx.sum(T * M)
+ srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T,
+ srgw_term - lin_term))

if log:
return srfgw_dist, log_fgw
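The `else` branch above also backpropagates through `alpha`; the formula follows directly from the decomposition of the fused objective (a one-line check, not from the diff). Since `srfgw_dist = (1 - alpha) * lin_term + alpha * srgw_term`,

```latex
\frac{\partial\, \mathrm{srFGW}_\alpha}{\partial \alpha}
  = \mathrm{srgw\_term} - \mathrm{lin\_term},
\qquad
\mathrm{lin\_term} = \langle \mathbf{T}, \mathbf{M} \rangle_F,
\quad
\mathrm{srgw\_term} = \frac{\mathrm{srFGW}_\alpha - (1 - \alpha)\, \mathrm{lin\_term}}{\alpha}
```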
@@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo


def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
- M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
+ M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs):
"""
Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence.
@@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
cost_G : float
Value of the cost at `G`
- C1 : array-like (ns,ns)
- Structure matrix in the source domain.
- C2 : array-like (nt,nt)
- Structure matrix in the target domain.
+ C1 : array-like (ns,ns)
+ Transformed structure matrix in the source domain.
+ Note that for 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed.
+ C2 : array-like (nt,nt)
+ Transformed structure matrix in the target domain.
+ Note that for 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed.
ones_p: array-like (ns,1)
Array of ones of size ns
M : array-like (ns,nt)
Cost matrix between the features.
reg : float
Regularization parameter.
fC2t : array-like (nt,nt), optional
Transformed structure matrix in the target domain.
Note that for 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed.
If fC2t is not provided, it defaults to the 'square_loss' transformation, i.e. fC2t = C2.T ** 2.
alpha_min : float, optional
Minimum value for alpha
alpha_max : float, optional
@@ -565,11 +572,14 @@

qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0)
dot = nx.dot(nx.dot(C1, deltaG), C2.T)
- C2t_square = C2.T ** 2
- dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square)
- dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square)
- a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG)
- b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
+ if fC2t is None:
+ fC2t = C2.T ** 2
+ dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t)
+ dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t)

+ a = reg * nx.sum((dot_qdeltaG - dot) * deltaG)
+ b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG))

alpha = solve_1d_linesearch_quad(a, b)
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
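To make the new `fC2t` plumbing concrete, a small self-contained sketch of how the solvers above call this function. The toy data is assumed; `G`, `deltaG` and `cost_G` would normally come from the Frank-Wolfe iteration rather than be built by hand.

```python
# Illustrative wiring of init_matrix_semirelaxed into the line search.
import numpy as np
from ot.gromov import init_matrix_semirelaxed, solve_semirelaxed_gromov_linesearch

rng = np.random.RandomState(0)
ns, nt = 8, 6
C1 = rng.rand(ns, ns); C1 = (C1 + C1.T) / 2  # symmetric source structure
C2 = rng.rand(nt, nt); C2 = (C2 + C2.T) / 2  # symmetric target structure
p = np.full(ns, 1.0 / ns)

# hC1, hC2 and fC2t are the transformed matrices the docstring refers to.
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun='kl_loss')

G = np.outer(p, np.full(nt, 1.0 / nt))           # current plan, rows sum to p
Gc = np.zeros((ns, nt))
Gc[np.arange(ns), rng.randint(nt, size=ns)] = p  # another plan with rows p
deltaG = Gc - G                                  # feasible FW direction
ones_p = np.ones(ns)

# cost_G=0. for illustration; the solvers pass the current cost value.
alpha, fc, cost = solve_semirelaxed_gromov_linesearch(
    G, deltaG, cost_G=0., C1=hC1, C2=hC2, ones_p=ones_p,
    M=0., reg=1., fC2t=fC2t, alpha_min=0., alpha_max=1.)
```

The returned `alpha` is the exact minimizer of the quadratic cost along `G + t * deltaG` on [0, 1], obtained via `solve_1d_linesearch_quad(a, b)`.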
@@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein(
If left to its default value None, the uniform distribution is taken.
loss_fun : str
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
@@ -655,8 +664,6 @@
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- if loss_fun == 'kl_loss':
- raise NotImplementedError()
arr = [C1, C2]
if p is not None:
arr.append(list_to_array(p))
@@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2(
If left to its default value None, the uniform distribution is taken.
loss_fun : str
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
@@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
If left to its default value None, the uniform distribution is taken.
loss_fun : str
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
@@ -907,8 +912,6 @@
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- if loss_fun == 'kl_loss':
- raise NotImplementedError()
arr = [M, C1, C2]
if p is not None:
arr.append(list_to_array(p))
@@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2(
If left to its default value None, the uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver, either 'square_loss' or 'kl_loss'.
- 'kl_loss' is not implemented yet and will raise an error.
epsilon : float
Regularization term >0
symmetric : bool, optional
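The entropic (mirror-descent) variants gain the same loss option. An illustrative call, not from the diff, with `epsilon` and `alpha` chosen arbitrarily rather than recommended:

```python
# Illustrative only: entropic semi-relaxed solvers with the new KL loss.
import numpy as np
import ot

rng = np.random.RandomState(0)
ns, nt = 20, 15
C1, C2 = ot.dist(rng.randn(ns, 2)), ot.dist(rng.randn(nt, 2))
C1, C2 = C1 / C1.max(), C2 / C2.max()
p = ot.unif(ns)
M = np.abs(rng.randn(ns, nt))  # feature cost matrix for the fused variant

T_ent = ot.gromov.entropic_semirelaxed_gromov_wasserstein(
    C1, C2, p, loss_fun='kl_loss', epsilon=0.1)
srfgw_val = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(
    M, C1, C2, p, loss_fun='kl_loss', epsilon=0.1, alpha=0.5)
```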
27 changes: 25 additions & 2 deletions ot/gromov/_utils.py
@@ -399,6 +399,19 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
h_2(b) &= 2b
The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as:

.. math::

    L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)

    \mathrm{with} \ f_1(a) &= a \log(a) - a

    f_2(b) &= b

    h_1(a) &= a

    h_2(b) &= \log(b)
Parameters
----------
C1 : array-like, shape (ns, ns)
@@ -451,9 +464,19 @@ def h1(a):
def h2(b):
return 2 * b
elif loss_fun == 'kl_loss':
- raise NotImplementedError()
+ def f1(a):
+ return a * nx.log(a + 1e-15) - a

+ def f2(b):
+ return b

+ def h1(a):
+ return a

+ def h2(b):
+ return nx.log(b + 1e-15)
else:
- raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Only 'square_loss' is supported.")
+ raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
nx.ones((1, C2.shape[0]), type_as=p))
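A quick numerical check of this decomposition (illustrative, not part of the commit or its test suite):

```python
# Check that f1(a) + f2(b) - h1(a) * h2(b) recovers the KL loss elementwise.
import numpy as np

rng = np.random.RandomState(0)
a, b = rng.rand(4, 4) + 0.1, rng.rand(4, 4) + 0.1
L_dec = (a * np.log(a + 1e-15) - a) + b - a * np.log(b + 1e-15)
L_kl = a * np.log(a / b) - a + b
assert np.allclose(L_dec, L_kl)
```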