diff --git a/tests/test_transition_matrix_solver.py b/tests/test_transition_matrix_solver.py index 9ac349e7..179d8893 100644 --- a/tests/test_transition_matrix_solver.py +++ b/tests/test_transition_matrix_solver.py @@ -99,6 +99,36 @@ def test_matrix_fit_predict_not_strict(): np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) +def test_ridge_matrix_fit_predict(): + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ) + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ) + + expected = np.array([[0.479416, 0.520584], [0.455918, 0.544082]]) + + tms = TransitionMatrixSolver(lam=1) + current = tms.fit_predict(X, Y) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) + + def test_matrix_get_prediction_interval(): with pytest.raises(NotImplementedError): tms = TransitionMatrixSolver()