From 45796780c301d9216eb66dbfa41de6f6cdf27ba7 Mon Sep 17 00:00:00 2001 From: Diane Napolitano Date: Mon, 22 Jan 2024 15:55:33 -0500 Subject: [PATCH] Adding unit test for strict constraints with matrix solver and global RTOL and ATOL constants for all matrix solver unit tests --- tests/test_transition_matrix_solver.py | 45 ++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/test_transition_matrix_solver.py b/tests/test_transition_matrix_solver.py index 62129412..9ac349e7 100644 --- a/tests/test_transition_matrix_solver.py +++ b/tests/test_transition_matrix_solver.py @@ -3,6 +3,9 @@ from elexsolver.TransitionMatrixSolver import BootstrapTransitionMatrixSolver, TransitionMatrixSolver +RTOL = 1e-04 +ATOL = 1e-04 + def test_matrix_fit_predict(): X = np.array( @@ -31,7 +34,7 @@ def test_matrix_fit_predict(): tms = TransitionMatrixSolver() current = tms.fit_predict(X, Y) - np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) def test_matrix_fit_predict_with_weights(): @@ -63,7 +66,37 @@ def test_matrix_fit_predict_with_weights(): tms = TransitionMatrixSolver() current = tms.fit_predict(X, Y, weights=weights) - np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) + + +def test_matrix_fit_predict_not_strict(): + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ) + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ) + + expected = np.array([[0.760451, 0.239558], [0.216624, 0.783369]]) + + tms = TransitionMatrixSolver(strict=False) + current = tms.fit_predict(X, Y) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) def test_matrix_get_prediction_interval(): @@ -99,7 +132,7 @@ def test_bootstrap_fit_predict(): btms = BootstrapTransitionMatrixSolver(B=10, verbose=False) current = btms.fit_predict(X, Y) - np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) def test_bootstrap_fit_predict_with_weights(): @@ -131,7 +164,7 @@ def test_bootstrap_fit_predict_with_weights(): btms = BootstrapTransitionMatrixSolver(B=10, verbose=False) current = btms.fit_predict(X, Y, weights=weights) - np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) def test_bootstrap_confidence_interval(): @@ -163,5 +196,5 @@ def test_bootstrap_confidence_interval(): btms = BootstrapTransitionMatrixSolver(B=10, verbose=False) _ = btms.fit_predict(X, Y) (current_lower, current_upper) = btms.get_confidence_interval(0.95) - np.testing.assert_allclose(expected_lower, current_lower, rtol=1e-08, atol=1e-02) - np.testing.assert_allclose(expected_upper, current_upper, rtol=1e-08, atol=1e-02) + np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL) + np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL)