Skip to content

Commit

Permalink
Adding a matrix solver unit test where the matrix needs to be pivoted…
Browse files Browse the repository at this point in the history
… first and clarifying the use of lambda and L2 regularization
  • Loading branch information
dmnapolitano committed Jan 22, 2024
1 parent bfd05a8 commit 1b583b8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class TransitionMatrixSolver(TransitionSolver):
def __init__(self, strict: bool = True, lam: float | None = None):
"""
`lam` > 0 will enable L2 regularization (Ridge).
`lam` != 0 will enable L2 regularization (Ridge).
"""
super().__init__()
self._strict = strict
Expand Down
30 changes: 30 additions & 0 deletions tests/test_transition_matrix_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,36 @@ def test_ridge_matrix_fit_predict():
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_matrix_fit_predict_pivoted():
X = np.array(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12],
]
).T

Y = np.array(
[
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11],
[12, 13],
]
).T

expected = np.array([[0.760428, 0.239572], [0.216642, 0.783358]])

tms = TransitionMatrixSolver()
current = tms.fit_predict(X, Y)
np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL)


def test_matrix_get_prediction_interval():
with pytest.raises(NotImplementedError):
tms = TransitionMatrixSolver()
Expand Down

0 comments on commit 1b583b8

Please sign in to comment.