Skip to content

Commit

Permalink
fix solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
cedricvincentcuaz committed Jan 6, 2025
1 parent e5c4711 commit 8e79b24
Showing 1 changed file with 75 additions and 48 deletions.
123 changes: 75 additions & 48 deletions test/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,55 +432,82 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha
M = np.random.rand(n_samples_s, n_samples_t)

try:
sol0 = ot.solve_gromov(
Ca,
Cb,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
sol0_fgw = ot.solve_gromov(
Ca,
Cb,
M,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

# solve in backend
ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb)
if unbalanced_type == "partial" and unbalanced is None:
ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb)

with pytest.raises(ValueError):
solx = ot.solve_gromov(
Cax,
Cbx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
with pytest.raises(ValueError):
solx_fgw = ot.solve_gromov(
Cax,
Cbx,
Mx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

solx = ot.solve_gromov(
Cax,
Cbx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
solx_fgw = ot.solve_gromov(
Cax,
Cbx,
Mx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

solx.value_quad

assert_allclose_sol(sol0, solx)
assert_allclose_sol(sol0_fgw, solx_fgw)
else:
sol0 = ot.solve_gromov(
Ca,
Cb,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
sol0_fgw = ot.solve_gromov(
Ca,
Cb,
M,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

# solve in backend
ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb)

solx = ot.solve_gromov(
Cax,
Cbx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
loss=loss,
) # GW
solx_fgw = ot.solve_gromov(
Cax,
Cbx,
Mx,
reg=reg,
reg_type=reg_type,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
alpha=alpha,
loss=loss,
) # FGW

solx.value_quad

assert_allclose_sol(sol0, solx)
assert_allclose_sol(sol0_fgw, solx_fgw)

except NotImplementedError:
pytest.skip("Not implemented")
Expand Down

0 comments on commit 8e79b24

Please sign in to comment.