From 8e79b24932d5e3382adf5506d6f05140d1dd2ad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 02:08:05 +0100 Subject: [PATCH] fix solvers --- test/test_solvers.py | 123 ++++++++++++++++++++++++++----------------- 1 file changed, 75 insertions(+), 48 deletions(-) diff --git a/test/test_solvers.py b/test/test_solvers.py index a0c1d7c43..b1bd097a3 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -432,55 +432,82 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha M = np.random.rand(n_samples_s, n_samples_t) try: - sol0 = ot.solve_gromov( - Ca, - Cb, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - sol0_fgw = ot.solve_gromov( - Ca, - Cb, - M, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - # solve in backend - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + if unbalanced_type == "partial" and unbalanced is None: + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + with pytest.raises(ValueError): + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + with pytest.raises(ValueError): + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - solx.value_quad - - assert_allclose_sol(sol0, solx) - assert_allclose_sol(sol0_fgw, solx_fgw) + else: + sol0 = ot.solve_gromov( + Ca, + Cb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + sol0_fgw = ot.solve_gromov( + Ca, + Cb, + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + solx.value_quad + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: pytest.skip("Not implemented")