diff --git a/test/test_solvers.py b/test/test_solvers.py index a0c1d7c43..b1bd097a3 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -432,55 +432,82 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha M = np.random.rand(n_samples_s, n_samples_t) try: - sol0 = ot.solve_gromov( - Ca, - Cb, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - sol0_fgw = ot.solve_gromov( - Ca, - Cb, - M, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - # solve in backend - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + if unbalanced_type == "partial" and unbalanced is None: + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + with pytest.raises(ValueError): + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + with pytest.raises(ValueError): + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - solx.value_quad - - assert_allclose_sol(sol0, solx) - assert_allclose_sol(sol0_fgw, solx_fgw) + else: + sol0 = ot.solve_gromov( + Ca, + Cb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + sol0_fgw = ot.solve_gromov( + Ca, + Cb, + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + solx.value_quad + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: pytest.skip("Not implemented")