Skip to content

Commit

Permalink
Adding unit test for matrix solver with L2 regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Jan 22, 2024
1 parent 4579678 commit bfd05a8
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/test_transition_matrix_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,36 @@ def test_matrix_fit_predict_not_strict():
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_ridge_matrix_fit_predict():
X = np.array(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12],
]
)

Y = np.array(
[
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11],
[12, 13],
]
)

expected = np.array([[0.479416, 0.520584], [0.455918, 0.544082]])

tms = TransitionMatrixSolver(lam=1)
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_matrix_get_prediction_interval():
with pytest.raises(NotImplementedError):
tms = TransitionMatrixSolver()
Expand Down

0 comments on commit bfd05a8

Please sign in to comment.