diff --git a/tests/test_optim.py b/tests/test_optim.py index 9018be1..53bb7e5 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -140,31 +140,31 @@ def test_gradopt(): X, Y = data_create(20, 5) alpha = np.random.normal(size=5) alpha2 = graduated_optimisation(alpha, X, Y, 0.1, beta=100) - assert loss_smooth(alpha, X, Y, 0.1, beta=100) > loss_smooth( + assert loss_smooth(alpha, X, Y, 0.1, beta=100) >= loss_smooth( alpha2, X, Y, 0.1, beta=100 ) alpha2 = graduated_optimisation(alpha, X, Y, 0.1, beta=100, lambda1=0.5) - assert loss_smooth(alpha, X, Y, 0.1, beta=100, lambda1=0.5) > loss_smooth( + assert loss_smooth(alpha, X, Y, 0.1, beta=100, lambda1=0.5) >= loss_smooth( alpha2, X, Y, 0.1, beta=100, lambda1=0.5 ) alpha2 = graduated_optimisation(alpha, X, Y, 0.1, beta=100, lambda2=0.5) - assert loss_smooth(alpha, X, Y, 0.1, beta=100, lambda2=0.5) > loss_smooth( + assert loss_smooth(alpha, X, Y, 0.1, beta=100, lambda2=0.5) >= loss_smooth( alpha2, X, Y, 0.1, beta=100, lambda2=0.5 ) # With weight w = np.random.uniform(size=20) alpha2 = graduated_optimisation(alpha, X, Y, 0.1, beta=100, weight=w) - assert loss_smooth(alpha, X, Y, 0.1, beta=100, weight=w) > loss_smooth( + assert loss_smooth(alpha, X, Y, 0.1, beta=100, weight=w) >= loss_smooth( alpha2, X, Y, 0.1, beta=100, weight=w ) alpha2 = graduated_optimisation(alpha, X, Y, 0.1, beta=100, lambda1=0.5, weight=w) - assert loss_smooth(alpha, X, Y, 0.1, beta=100, lambda1=0.5, weight=w) > loss_smooth( - alpha2, X, Y, 0.1, beta=100, lambda1=0.5, weight=w - ) + assert loss_smooth( + alpha, X, Y, 0.1, beta=100, lambda1=0.5, weight=w + ) >= loss_smooth(alpha2, X, Y, 0.1, beta=100, lambda1=0.5, weight=w) alpha2 = graduated_optimisation(alpha, X, Y, 0.1, beta=100, lambda2=0.5, weight=w) - assert loss_smooth(alpha, X, Y, 0.1, beta=100, lambda2=0.5, weight=w) > loss_smooth( - alpha2, X, Y, 0.1, beta=100, lambda2=0.5, weight=w - ) + assert loss_smooth( + alpha, X, Y, 0.1, beta=100, lambda2=0.5, weight=w + ) >= loss_smooth(alpha2, X, Y, 0.1, beta=100, lambda2=0.5, weight=w) def test_regres():