Skip to content

Commit

Permalink
Adding unit test for strict constraints with matrix solver and global…
Browse files Browse the repository at this point in the history
… RTOL and ATOL constants for all matrix solver unit tests
  • Loading branch information
dmnapolitano committed Jan 22, 2024
1 parent 12861ed commit 4579678
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions tests/test_transition_matrix_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from elexsolver.TransitionMatrixSolver import BootstrapTransitionMatrixSolver, TransitionMatrixSolver

RTOL = 1e-04
ATOL = 1e-04


def test_matrix_fit_predict():
X = np.array(
Expand Down Expand Up @@ -31,7 +34,7 @@ def test_matrix_fit_predict():

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_with_weights():
Expand Down Expand Up @@ -63,7 +66,37 @@ def test_matrix_fit_predict_with_weights():

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y, weights=weights)
np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_not_strict():
X = np.array(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12],
]
)

Y = np.array(
[
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11],
[12, 13],
]
)

expected = np.array([[0.760451, 0.239558], [0.216624, 0.783369]])

tms = TransitionMatrixSolver(strict=False)
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_matrix_get_prediction_interval():
Expand Down Expand Up @@ -99,7 +132,7 @@ def test_bootstrap_fit_predict():

btms = BootstrapTransitionMatrixSolver(B=10, verbose=False)
current = btms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_bootstrap_fit_predict_with_weights():
Expand Down Expand Up @@ -131,7 +164,7 @@ def test_bootstrap_fit_predict_with_weights():

btms = BootstrapTransitionMatrixSolver(B=10, verbose=False)
current = btms.fit_predict(X, Y, weights=weights)
np.testing.assert_allclose(expected, current, rtol=1e-08, atol=1e-02)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_bootstrap_confidence_interval():
Expand Down Expand Up @@ -163,5 +196,5 @@ def test_bootstrap_confidence_interval():
btms = BootstrapTransitionMatrixSolver(B=10, verbose=False)
_ = btms.fit_predict(X, Y)
(current_lower, current_upper) = btms.get_confidence_interval(0.95)
np.testing.assert_allclose(expected_lower, current_lower, rtol=1e-08, atol=1e-02)
np.testing.assert_allclose(expected_upper, current_upper, rtol=1e-08, atol=1e-02)
np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL)

0 comments on commit 4579678

Please sign in to comment.