From 1b583b8a3a6ded769c34ca1afd53133a4048807d Mon Sep 17 00:00:00 2001 From: Diane Napolitano Date: Mon, 22 Jan 2024 16:29:22 -0500 Subject: [PATCH] Adding a matrix solver unit test where the matrix needs to be pivoted first and clarifying the use of lambda and L2 regularization --- src/elexsolver/TransitionMatrixSolver.py | 2 +- tests/test_transition_matrix_solver.py | 30 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/elexsolver/TransitionMatrixSolver.py b/src/elexsolver/TransitionMatrixSolver.py index 30d1ef7f..bdd7f988 100644 --- a/src/elexsolver/TransitionMatrixSolver.py +++ b/src/elexsolver/TransitionMatrixSolver.py @@ -16,7 +16,7 @@ class TransitionMatrixSolver(TransitionSolver): def __init__(self, strict: bool = True, lam: float | None = None): """ - `lam` > 0 will enable L2 regularization (Ridge). + `lam` != 0 will enable L2 regularization (Ridge). """ super().__init__() self._strict = strict diff --git a/tests/test_transition_matrix_solver.py b/tests/test_transition_matrix_solver.py index 179d8893..8cf83d89 100644 --- a/tests/test_transition_matrix_solver.py +++ b/tests/test_transition_matrix_solver.py @@ -129,6 +129,36 @@ def test_ridge_matrix_fit_predict(): np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) +def test_matrix_fit_predict_pivoted(): + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ).T + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ).T + + expected = np.array([[0.760428, 0.239572], [0.216642, 0.783358]]) + + tms = TransitionMatrixSolver() + current = tms.fit_predict(X, Y) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) + + def test_matrix_get_prediction_interval(): with pytest.raises(NotImplementedError): tms = TransitionMatrixSolver()