From 98b9e8168802b66aaba2b1274bb98fab3ebbbfcb Mon Sep 17 00:00:00 2001 From: Diane Napolitano Date: Mon, 22 Jan 2024 16:59:04 -0500 Subject: [PATCH] Two more unit tests on the bootstrap confidence interval --- tests/test_transition_matrix_solver.py | 63 ++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/test_transition_matrix_solver.py b/tests/test_transition_matrix_solver.py index 8cf83d89..60f4b1c2 100644 --- a/tests/test_transition_matrix_solver.py +++ b/tests/test_transition_matrix_solver.py @@ -258,3 +258,66 @@ def test_bootstrap_confidence_interval(): (current_lower, current_upper) = btms.get_confidence_interval(0.95) np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL) np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL) + + +def test_bootstrap_confidence_interval_greater_than_1(): + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ) + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ) + + expected_lower = np.array([[0.757573, 0.095978], [0.09128, 0.779471]]) + expected_upper = np.array([[0.904022, 0.242427], [0.220529, 0.90872]]) + + btms = BootstrapTransitionMatrixSolver(B=10, verbose=False) + _ = btms.fit_predict(X, Y) + (current_lower, current_upper) = btms.get_confidence_interval(95) + np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL) + np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL) + + +def test_bootstrap_confidence_interval_invalid(): + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ) + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ) + + btms = BootstrapTransitionMatrixSolver(B=10, verbose=False) + _ = btms.fit_predict(X, Y) + + with pytest.raises(ValueError): + btms.get_confidence_interval(-34)